commit fe485dd86d161e2e20efb6fc1d2d619a4106ef8e Author: urania Date: Mon Jun 22 16:06:57 2026 +0200 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ab844c4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +# Build artifacts +/main +/server +/nadir + +# Local environment / secrets +.env +config.yaml +config.yml + +# Editor +*.swp +server + +CLAUDE.md \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a43fbc3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 urania + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..8fcecb0 --- /dev/null +++ b/README.md @@ -0,0 +1,497 @@ +# Nadir + +Nadir is a lightweight, modular Linux system-administration backend - a modern, +FOSS system admin panels. It exposes a typed REST API for the everyday tasks +you'd otherwise SSH in to do: inspect the host, manage systemd services, edit +local users and groups, install packages, and read logs - all behind +role-based access control and a tamper-evident audit trail. + +The API is generated with [Huma](https://huma.rocks) (OpenAPI 3.1) and ships +interactive docs at `/docs`. The backend is Go and self-contained: no external +database, no agent, no runtime dependencies beyond the standard system tools it +drives (`systemctl`, `hostnamectl`, `useradd`, the host package manager, …). + +--- + +## What it does + +Functionality is organized into **modules**. Each module owns a slice of the +API and declares its own permission vocabulary. + +- **System** - Dashboard overview (OS/kernel, CPU, memory, disks, load, uptime, + network interfaces, temperatures); get/set hostname; time, timezone, and NTP; + locale and console keymap; reboot and power off. +- **Services** - List and inspect systemd units; start / stop / restart / enable + / disable; read service logs from the journal or an allowlisted file, as a + snapshot or a live Server-Sent-Events stream. +- **Users** - List, inspect, create, and delete local accounts; set a password; + set supplementary groups. +- **Groups** - List, inspect, create, and delete local groups. +- **Packages** - List installed packages and available updates; install, remove, + and upgrade - streamed live over SSE. Auto-detects `dnf`, `apt`, or `pacman`. +- **Networking** - List network interfaces, routing tables, and DNS settings; configure IPv4 settings with temporary applying and safety auto-rollback; bring interfaces up or down. +- **Audit** - Read-only trail of every privileged write (who, what, when, result). +- **Terminal** - Interactive shell access. Upgrades connection to a WebSocket and spawns a PTY shell as the logged-in user (requires `root` permission). +- **Meta** - Self-description for clients: `/api/_modules`, `/api/whoami`, + `/api/health`. + +### Security model at a glance + +- **Authentication** is delegated to PAM (`pam_unix`), so logins use real system + credentials. A successful login sets an `HttpOnly`, `SameSite=Strict` session + cookie; sessions are stored in SQLite and survive restarts. +- **Machine credentials** for non-interactive callers (e.g. a central dashboard + managing many nodes) authenticate with a static `Authorization: Bearer nad_…` + token instead of a PAM session. Mint with `nadir token add ` (shown once, + only its SHA-256 is stored); revoke with `nadir token rm ` (immediate, no + restart). A token is an ordinary RBAC subject - its name is assigned a role in + `config.yaml` `assignments`, so a leaked token is scoped, not implicitly admin. + The audit trail records the actor as `token:` to distinguish it from a + human. CSRF does not apply: browsers never auto-attach a Bearer header, so the + same-origin cookie defense is irrelevant for token auth. Bad-token guesses are + throttled per source IP. +- **Authorization** is RBAC driven entirely by `config.yaml`. Every protected + operation declares a `module` and one of three permission tiers: + - `read` - inspect (list users, read status, view logs…) + - `write` - routine changes (create a user, restart a service, set the hostname…) + - `root` - high-impact or irreversible actions (reboot, delete an account, + **reset a password**, **change group membership**). Password and group- + membership changes are `root` precisely because they can hand someone root. +- **Brute-force throttling** on login (per username + source IP cooldown). +- **CSRF** defense via `SameSite=Strict` plus a same-origin check on writes. +- **Audit** of every mutation, written off the request path to SQLite. +- The server **must run as root** - PAM reads `/etc/shadow`, and the system + tools it drives (`hostnamectl`, `systemctl`, `useradd`, `shutdown`, …) require + it. + + + +--- + +## Installing + +### Prerequisites + +- Linux with **systemd** (the Services module and the `nadir` service wrapper + use it). +- **Root** access (see above). +- Go (recent) to build from source. + +### Build + +The entry point is the `main` package under `cmd/server`: + +```bash +go build -o nadir ./cmd/server +``` + +This produces a single static-ish binary, `nadir`. + +### Run directly + +On first start, `nadir` requires a configuration file to exist. If the configuration is missing, the server will fail to start and ask you to run `nadir install` (to install the systemd service) or use `--save-config`. + +To generate a default configuration file (assigning the admin role to your current user) without installing the systemd service: + +```bash +./nadir --save-config +``` + +To save it for the root user (who runs the server): + +```bash +sudo ./nadir --save-config +``` + +You can also specify a custom path using `-f`/`--config`: + +```bash +./nadir --save-config -f ./config.yaml +``` + +Once the configuration file is created, start the server directly: + +```bash +sudo ./nadir # same as: sudo ./nadir run +``` + +By default it reads `~/.config/config.yaml` (resolving to the running user's home, i.e., `/root/.config/config.yaml` when run as root); override with the `-f`/`--config` flag or `CONFIG_PATH` env var: + +```bash +sudo ./nadir -f /etc/nadir/config.yaml +# or: sudo CONFIG_PATH=/etc/nadir/config.yaml ./nadir +``` + +By default it serves **HTTPS** with a self-signed certificate (see +[Deployment note 2](#2-tls-three-modes)) on the `hostname:port` from the config, +and exposes interactive docs at `https://:/docs` and the raw spec at +`/openapi.json`. + +### Run in the background (`-d`) + +Like `docker run -d`, this detaches from the terminal and returns your shell: + +```bash +sudo ./nadir run -d +# nadir running in background (pid 12345); logs: /var/lib/nadir/server.log +# follow with: nadir logs +``` + +Output goes to `/var/lib/nadir/server.log`. + +### Install as a systemd service (start on boot) + +For a real deployment, register nadir as a service so it starts on boot and is +managed with the usual tooling: + +```bash +sudo ./nadir install # writes the unit, enables it, and starts it now +sudo ./nadir status +sudo ./nadir logs # follow the journal live +``` + +`install` writes `/etc/systemd/system/nadir.service` pinning the **absolute** +binary and the absolute config file path (so it doesn't depend on the working directory at +boot), runs `systemctl daemon-reload`, and `enable --now`. If no configuration file +exists at the target path, `install` automatically creates a default config file and +assigns the admin role to the installing user. + +### CLI reference + +| Command | Effect | +| ------------------------------------------------ | --------------------------------------------------------------------------- | +| `nadir [run] [-d]` | Start the server. `-d` / `--detach` runs it in the background. | +| `nadir --save-config` | Save the default configuration template to the target path and exit. | +| `nadir install` | Install + enable the systemd service (starts now and on boot). | +| `nadir uninstall` | Stop, disable, and remove the systemd service. | +| `nadir start` \| `stop` \| `restart` \| `status` | Control the running service. | +| `nadir enable` \| `disable` | Toggle start-on-boot without removing the unit. | +| `nadir logs` | Follow logs - journald if installed as a service, otherwise the detach log. | +| `nadir help` | Show usage. | + +Most commands need root. + +--- + +## Configuration (`config.yaml`) + +`config.yaml` is the single source of truth for runtime configuration: server +and TLS settings, which roles exist, what each role can do, and who holds which +role. By default, it reads `~/.config/config.yaml`. The path can be overridden using +the `-f` / `--config` CLI flags or the `CONFIG_PATH` environment variable. + +```yaml +server: + secure_tls: true # Secure flag on the session cookie (keep true behind TLS) + trust_proxy: true # a reverse proxy terminates TLS; see Deployment note 3 + # tls_cert: /etc/nadir/tls/cert.pem # or terminate TLS in nadir yourself + # tls_key: /etc/nadir/tls/key.pem + hostname: 100.64.0.189 + port: 9999 + +# Quote "*" - bare * is YAML alias syntax and fails to parse. +roles: + admin: + "*": ["*"] # every permission on every module (including future ones) + auditor: + "*": ["read"] # read-only everywhere + system_ops: + system: ["read", "write"] + +assignments: + urania: [admin] + +# Optional: per-unit allowlist of log files the Services module may read. +log_files: + nginx: + - /var/log/nginx/access.log + - /var/log/nginx/error.log +``` + +### `server` + +| Key | Default | Meaning | +| --------------------- | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `secure_tls` | `true` | Sets the `Secure` flag on the session cookie. Keep `true` whenever the browser reaches nadir over HTTPS (direct or via proxy); `false` only for local plain-HTTP dev. | +| `trust_proxy` | `false` | When `true`, nadir serves plaintext HTTP and trusts `X-Forwarded-For` / forwarded `Host` from the proxy. See [Deployment note 3](#3-reverse-proxy--vpn). | +| `tls_cert`, `tls_key` | - | PEM paths. When both are set (and `trust_proxy` is off), nadir terminates TLS with this pair. | +| `hostname` | - | Address to bind. Use `127.0.0.1` for local-only, or an overlay/VPN address to expose nadir only on that interface. | +| `port` | - | TCP port to listen on. | + +TLS selection is covered in [Deployment note 2](#2-tls-three-modes). + +### `roles` / `assignments` + +- `roles` maps a role name to `module → [permissions]`. `"*"` as the module key + means "all modules"; `"*"` in the permission list means "all permissions". +- `assignments` maps a username to the roles they hold; effective grants are the + union. +- `"*"` must be quoted - bare `*` is YAML alias syntax and fails to parse. +- Module keys and permissions are validated at startup against the modules + actually compiled in. An unknown module, an unexported permission, or an + assignment to an undefined role aborts startup with a clear message rather + than silently granting or denying access. +- Each module owns its permission vocabulary via `Permissions()`, so adding a + module automatically makes it available to wildcard roles and validatable for + restricted ones. Clients discover the live module/permission set at + `GET /api/_modules`, and a user's own grants at `GET /api/whoami`. + +### `log_files` + +An allowlist, keyed by unit, of log file paths the Services module is allowed to +read via the `source=file` log endpoints. The caller can only read paths an +admin has listed here - never an arbitrary file. + +--- + +## Deployment notes + +These notes capture the non-obvious operational decisions. They'll seed +the formal installation guide. + +### 1. PAM service + +**Nadir authenticates against its own PAM service, `/etc/pam.d/nadir`, and the +server creates that file on startup if it is missing** (see +`internal/auth/pamservice.go`). Here is why. + +#### What went wrong with stock services + +Originally we authenticated against the `"login"` service. On a Framework +laptop (and many other machines) `/etc/pam.d/login` pulls in `system-auth`, +whose auth stack lists `pam_fprintd.so` as `sufficient` **before** +`pam_unix.so`: + +``` +auth sufficient pam_fprintd.so # fingerprint, tried first +auth sufficient pam_unix.so nullok # password, only reached if fprintd fails +``` + +Our PAM conversation callback only answers the password prompt; it can't swipe +a finger. So `pam_fprintd` would start a fingerprint scan and **block until its +~30-second timeout** before falling through to the password check. Every login +took 30s. (It was never a network, D-Bus, systemd, or NSS problem — +`hostnamectl` was instant and there is no SSSD/LDAP on the box.) + +Switching to `"passwd"` is not a fix either: `/etc/pam.d/passwd` has only a +`password` stack and no `auth` stack, so it can't verify a login. + +#### The fix + +Ship a dedicated, minimal service - exactly what `sshd`, `cockpit`, and +`polkit` do. `/etc/pam.d/nadir` contains only: + +``` +#%PAM-1.0 +auth required pam_unix.so +account required pam_unix.so +``` + +That is a straight `/etc/shadow` password check plus an account-validity check +— no fingerprint, no systemd, no env loading, no DNS. Authentication drops from +~30s to milliseconds, and we stop inheriting whatever the distro's login stack +happens to do. + +Notes: + +- We omit `nullok` on purpose: this service is reachable over the network, and + `nullok` would let passwordless accounts log in. +- `EnsurePAMService()` **only writes the file when it is absent** - a missing + service falls through to `/etc/pam.d/other` (`pam_deny`), which looks identical + to "wrong credentials". If an admin customizes the file, nadir leaves it + untouched. +- `pam_unix` reads `/etc/shadow`, so the server must run as root. + +### 2. TLS: three modes + +Credentials and session cookies must never travel in cleartext. Nadir picks how +the connection is secured from `config.yaml`, in priority order: + +1. **Behind a reverse proxy** (`trust_proxy: true`) - a proxy such as Traefik + terminates TLS and forwards plaintext to nadir on a trusted network. Keep + `secure_tls: true` (the browser↔proxy leg is HTTPS). This is the deployment + covered in note 3. +2. **Nadir terminates TLS** (`tls_cert` + `tls_key`) - point both at a PEM + certificate/key pair and nadir serves HTTPS directly. Use this when there is + no proxy. +3. **Self-signed (dev only)** - when none of the above is configured, nadir + generates a fresh in-memory self-signed certificate (valid for `localhost` + and the loopback addresses, one year). Browsers will warn; that's expected. + Never rely on this in production. + +To create a persistent self-signed pair for mode 2 in development: + +```bash +openssl req -x509 -newkey rsa:2048 -nodes \ + -keyout key.pem -out cert.pem -days 365 \ + -subj "/O=nadir-dev-local/CN=localhost" \ + -addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1" +``` + +…then set `tls_cert`/`tls_key` to those paths. + +### 3. Reverse proxy + VPN + +When nadir runs behind a TLS-terminating reverse proxy (e.g. Traefik) on a +private overlay network, set `trust_proxy: true`. Nadir then serves plaintext +HTTP and trusts `X-Forwarded-For` (used by the login throttle) and the forwarded +`Host` (used by the CSRF same-origin check). **That trust is only safe if +nothing but the proxy can reach the app's port** - otherwise any client that +reaches it directly can forge those headers. + +The recommended shape: the proxy and the app each sit on a WireGuard-based +overlay, and nadir binds to its overlay address so the public/LAN interfaces +never answer. + +```yaml +server: + trust_proxy: true + secure_tls: true # browser↔proxy leg is HTTPS, so keep the cookie Secure + hostname: 100.64.0.189 # the app's overlay IP - only the VPN interface listens + port: 9999 +``` + +**Netbird / Tailscale** assign peers out of `100.64.0.0/10` (RFC 6598 CGNAT), +which is not publicly routable - binding there means only VPN peers can connect. +**Plain WireGuard** is the same idea with a private range you pick (e.g. +`10.0.0.0/24`); bind to the app's address on the `wg0` interface. + +Two things make the header trust airtight: + +1. **Restrict the port to the proxy peer only.** Binding to the overlay limits + reachability to _all_ VPN peers, not just the proxy. Tighten it so only the + proxy can reach `:9999`: + - _Netbird_: an access-control policy allowing the proxy peer/group → the app + peer on tcp/9999, denying others. + - _Tailscale_: an ACL rule (`"src": ["tag:proxy"], "dst": ["tag:nadir:9999"]`). + - _Plain WireGuard_: a host firewall rule on the app, e.g. + `iptables -A INPUT -i wg0 ! -s -p tcp --dport 9999 -j DROP`. + +2. **Make the proxy overwrite client-supplied forwarded headers.** Otherwise a + client sending its own `X-Forwarded-For` / `X-Forwarded-Host` can have it + passed through. In Traefik, mark the overlay as trusted on the entrypoint: + + ```yaml + # traefik static config + entryPoints: + websecure: + address: ":443" + forwardedHeaders: + trustedIPs: + - 100.64.0.0/10 # or your wg subnet, e.g. 10.0.0.0/24 + ``` + + And ensure it forwards the original host (Traefik does by default; nginx needs + `proxy_set_header Host $host;`), since the CSRF check compares `Origin` + against `Host`. + +With both in place, the only path to the app is proxy → overlay → app, and the +forwarded headers are trustworthy. Without step 1 you're trusting every peer on +the overlay - fine for a single-tenant network you fully control, risky on a +shared one. + +### 4. Connecting a dashboard (machine clients) + +To manage one or more Nadir instances via a central dashboard or non-interactive client, authenticate requests using a static Bearer token rather than interactive PAM credentials. + +Here is how to authorize and connect a dashboard: + +#### Step 1: Mint a token +Run `nadir token add ` (for example, `dashboard`) to generate a unique API key: +```bash +sudo nadir token add dashboard +``` +This generates a secure token starting with `nad_`. **Copy this token immediately**; only its SHA-256 hash is stored in `/var/lib/nadir/tokens.db` (shared via SQLite WAL between server and CLI), and the raw key cannot be retrieved again. + +#### Step 2: Authorize the token in `config.yaml` +Minting and authorizing are deliberately separate steps (safe default). A newly minted token does not grant any access. + +To grant the token a role, edit the `assignments` map in your `config.yaml`: +```yaml +assignments: + dashboard: [admin] # or another role like [system_ops] or [auditor] +``` +The audit log will record mutations performed by this token as `token:` (e.g., `token:dashboard`), distinguishing it from human logins. + +#### Step 3: Restart Nadir +While token creation and revocation (`nadir token rm`) are written to the database and take effect immediately, policy assignments live in `config.yaml`. To reload the configuration and authorize the new token name, you must restart the Nadir server: +```bash +sudo systemctl restart nadir +``` + +#### Step 4: Configure the dashboard client +Configure your client to include the token in the HTTP `Authorization` header of every API request: +```http +Authorization: Bearer nad_your_secret_token_here +``` + +#### Note on CORS / Cross-Origin requests +If your dashboard runs as a web application directly in the user's browser (cross-origin relative to the Nadir instance) and makes state-changing write requests (`POST`, `PUT`, `DELETE`), the browser will include an `Origin` header. + +To defend against CSRF, Nadir's middleware rejects state-changing requests if an `Origin` header is present and does not match the request's `Host` header. + +To connect a browser-based dashboard hosted on a different origin, choose one of these patterns: +1. **Server-to-Server Calls (Recommended):** Build the dashboard with a backend that calls Nadir's API. Because the backend is not a browser, it does not send an `Origin` header, allowing the requests to pass. +2. **Reverse Proxy:** Terminate the dashboard and the Nadir instance under the same origin (e.g., dashboard at `https://control.example.com/` and Nadir at `https://control.example.com/api/nadir-node-1/`), letting a reverse proxy route the requests. +3. **Header Rewriting:** Have a proxy in front of Nadir rewrite/strip the `Origin` header for authorized token requests before forwarding them to Nadir. + +--- + +## Layout + +``` +cmd/ process entry point + CLI (run / install / logs …), TLS, service wiring +internal/auth PAM auth, sessions, login/logout, login throttle, PAM service install +internal/config config.yaml loader + startup validation +internal/meta /api/_modules, /api/whoami, /api/health discovery endpoints +internal/module the Module interface +internal/modules concrete modules: + system - info, hostname, time/timezone/NTP, locale/keymap, power + services - systemd unit control + journal/file logs (snapshot + SSE) + users - local accounts + groups - local groups + packages - dnf/apt/pacman install/remove/upgrade (streamed) + audit - read-only audit trail + networking - network interfaces, routing tables, DNS, and IP configurations + terminal - interactive PTY shell over WebSocket +internal/oscmd shared command runner (timeouts, stderr surfacing) + helpers +internal/rbac roles, permissions ("*" wildcards), HTTP middleware (RBAC + CSRF) +internal/audit SQLite-backed audit log writer +``` + +## API docs + +With the server running, browse `https://:/docs` for the Scalar UI, +or fetch the raw OpenAPI document from `/openapi.json`. + +--- + +## Built with LLM assistance + +This project was built with the help of large language models - but every +architectural choice, security decision, and operational trade-off is the +author's. The LLM never drove; it was a power tool, not a co-pilot with the +wheel. + +In practice, the workflow looks like this: the author designs the feature, +decides how it should fit into the existing module structure, specifies the API +surface, and defines the security and permission semantics. The LLM then +accelerates the mechanical side - scaffolding boilerplate, drafting +implementations from precise instructions, generating documentation, and +proposing test cases. Every line of output is reviewed, corrected where needed, +and integrated only when it meets the project's standards. + +What the LLM provides is _commodity leverage_: it collapses the time between +"I know exactly what I want" and "it's written, tested, and documented." What +it does not provide is judgment - that stays with the person who understands the +system, its threat model, and its users. + +--- + +## License + +[MIT](./LICENSE) + +## Credits + +Favicon: [Orbit](https://lucide.dev/icons/orbit) from [Lucide](https://lucide.dev), recolored. Lucide icons are licensed under the [ISC License](https://github.com/lucide-icons/lucide/blob/main/LICENSE). diff --git a/apidoc.go b/apidoc.go new file mode 100644 index 0000000..ee29633 --- /dev/null +++ b/apidoc.go @@ -0,0 +1,28 @@ +// Package nadir exists only to embed shared documentation (the README) so a +// single source of truth can feed both GitHub and the OpenAPI description. +package nadir + +import _ "embed" + +// README is the project README. Content up to the "" marker +// is reused as the API description; see cmd/main.go. +// +//go:embed README.md +var README string + +// Favicon is the orbit icon (lucide.dev, recolored midnight-blue) served as +// both the app favicon and the /docs page icon. +// +//go:embed favicon.svg +var Favicon string + +// InstallScriptTemplate is the curl|sh bootstrap script served at /install.sh. +// It contains the placeholder __NADIR_BASE_URL__, substituted at request time +// with the scheme+host the script was fetched from - so it downloads the +// binary from the very instance that served it. See cmd/server/server.go. +// +//go:embed install.sh.tmpl +var InstallScriptTemplate string + +// Version is the current version of the Nadir application. +const Version = "0.0.1" diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 0000000..2847f51 --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,96 @@ +# ────────────────────────────────────────────────────────────────────── +# Nadir configuration - config.yaml +# +# This is the single source of truth for runtime settings: server/TLS, +# roles, role assignments, and log-file allowlists. The only env var +# nadir reads is CONFIG_PATH, the bootstrap pointer to this file. +# By default, nadir uses ~/.config/config.yaml. You can override this +# with the -f / --config flag or by setting the CONFIG_PATH env var. +# ────────────────────────────────────────────────────────────────────── + +server: + # Secure attribute on the session cookie. + # Keep true when the browser reaches nadir over HTTPS (direct or via proxy). + # Set false only for local plain-HTTP development. + secure_tls: true + + # TLS mode, in priority order: + # + # 1. trust_proxy: true + # A reverse proxy (e.g. Traefik) terminates TLS and forwards plaintext + # to nadir. Bind hostname to a private/overlay address so only the proxy + # can reach nadir. X-Forwarded-For is trusted in this mode. + # + # 2. tls_cert + tls_key (uncomment below) + # Nadir terminates TLS itself with your PEM certificate and key. + # + # 3. Neither (default) + # Nadir generates a fresh in-memory self-signed certificate on every + # start. Browsers will warn - this is for development only. + # + # Keep secure_tls: true in modes 1 and 2. + # trust_proxy: false + # tls_cert: /etc/nadir/tls/cert.pem + # tls_key: /etc/nadir/tls/key.pem + + # Address and port to bind. + # Use 127.0.0.1 for local-only, or an overlay/VPN address (e.g. a Netbird + # or Tailscale IP) to expose nadir only on that interface. + hostname: localhost + port: 9999 + +# ────────────────────────────────────────────────────────────────────── +# Roles +# +# Maps a role name to { module → [permissions] }. +# - Module key "*" means "all modules (including future ones)". +# - Permission "*" means "all permissions the module exports". +# - IMPORTANT: quote "*" - bare * is YAML alias syntax and fails to parse. +# +# Each module exports its own permission vocabulary via Permissions(). +# Valid tiers are: read, write, root. Unknown modules or permissions +# cause a startup error, not a silent denial. +# ────────────────────────────────────────────────────────────────────── +roles: + # Full access - every permission on every module. + admin: + "*": ["*"] + + # Read-only on all modules - good for monitoring dashboards. + # auditor: + # "*": ["read"] + + # Scoped operator - can read and write the system module only. + # system_ops: + # system: ["read", "write"] + +# ────────────────────────────────────────────────────────────────────── +# Assignments +# +# Maps a local username to one or more roles. Effective grants are the +# union of all assigned roles' permissions. The username must match a +# real system account (PAM authenticates against /etc/shadow). +# +# A machine credential (Bearer token) is assigned a role the same way, by its +# token name rather than a system username. Mint with `nadir token add +# central-dashboard`, then grant it scoped access below. Until it is listed +# here the token authenticates but can do nothing. +# ────────────────────────────────────────────────────────────────────── +assignments: + # Replace with your admin username. + ubuntu: [admin] + # central-dashboard: [auditor] + +# ────────────────────────────────────────────────────────────────────── +# Log files (optional) +# +# Per-unit allowlist of log files the Services module may serve via the +# source=file log endpoints. Only paths listed here are readable - the +# caller can never request an arbitrary file. +# ────────────────────────────────────────────────────────────────────── +# log_files: +# nginx: +# - /var/log/nginx/access.log +# - /var/log/nginx/error.log +# files: +# - /var/log/auth.log diff --git a/favicon.svg b/favicon.svg new file mode 100644 index 0000000..2507875 --- /dev/null +++ b/favicon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..823623b --- /dev/null +++ b/go.mod @@ -0,0 +1,26 @@ +module nadir + +go 1.26.4 + +require github.com/msteinert/pam v1.2.0 + +require github.com/danielgtaylor/huma/v2 v2.38.0 + +require ( + gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.52.0 + github.com/coder/websocket v1.8.15 + github.com/creack/pty v1.1.24 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-isatty v0.0.21 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/sys v0.43.0 // indirect + modernc.org/libc v1.72.3 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6fbbbbb --- /dev/null +++ b/go.sum @@ -0,0 +1,74 @@ +github.com/coder/websocket v1.8.15 h1:6B2JPeOGlpff2Uz6vOEH1Vzpi0iUz20A+lPVhPHtNUA= +github.com/coder/websocket v1.8.15/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= +github.com/danielgtaylor/huma/v2 v2.38.0 h1:fb0WZCatnaiHLphMQDDWDjygNxfMkX/ENma3QsRl7vY= +github.com/danielgtaylor/huma/v2 v2.38.0/go.mod h1:k9hwjlgWFt1t2jsmQGlsgXAG2FBTZa4kkjV581qAtfo= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ= +github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= +github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= +github.com/msteinert/pam v1.2.0 h1:mYfjlvN2KYs2Pb9G6nb/1f/nPfAttT/Jee5Sq9r3bGE= +github.com/msteinert/pam v1.2.0/go.mod h1:d2n0DCUK8rGecChV3JzvmsDjOY4R7AYbsNxAT+ftQl0= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY= +modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI= +modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ= +modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU= +modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg= +modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo= +modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/install.sh.tmpl b/install.sh.tmpl new file mode 100644 index 0000000..b77fb11 --- /dev/null +++ b/install.sh.tmpl @@ -0,0 +1,52 @@ +#!/bin/sh +set -e +# Nadir bootstrap installer. +# +# Downloads the nadir binary from the SAME instance that served this script +# and installs it as a systemd service. Intended for spinning up nadir on a +# sibling host quickly (same architecture as the source instance): +# +# curl -fsSL https://:/install.sh | sudo sh +# +# This is not a general-purpose distribution channel: there is no separate +# release server, no signature, and no version pinning beyond "whatever the +# source instance is currently running." Trust it exactly as much as you +# trust the source host. +# +# Everything is wrapped in do_install and invoked on the last line, the same +# pattern get.docker.com uses: when piped via `curl | sh`, sh executes while +# still receiving bytes from the network. Defining the whole body as a +# function before running anything means sh has fully parsed the script +# before do_install ever runs, which avoids exactly the failure mode where a +# command further down (here, the binary download) races the still-open +# script-fetch connection and fails with a write error. + +BASE_URL="__NADIR_BASE_URL__" + +do_install() { + if [ "$(id -u)" -ne 0 ]; then + echo "this script must be run as root (try: curl -fsSL $BASE_URL/install.sh | sudo sh)" >&2 + exit 1 + fi + + if ! command -v curl >/dev/null 2>&1; then + echo "curl is required" >&2 + exit 1 + fi + + echo "downloading nadir from $BASE_URL ..." + # No -s here: the progress meter is the whole point of leaving it visible, + # so you can see how much is left mid-download. + curl -f --progress-bar -L "$BASE_URL/nadir-binary" -o /usr/local/bin/nadir.tmp + mv /usr/local/bin/nadir.tmp /usr/local/bin/nadir + chmod +x /usr/local/bin/nadir + + echo "binary installed at /usr/local/bin/nadir" + echo "installing as a systemd service ..." + /usr/local/bin/nadir install + + echo + echo "done. check status with: nadir status" +} + +do_install diff --git a/internal/auditlog/audit.go b/internal/auditlog/audit.go new file mode 100644 index 0000000..f4ea613 --- /dev/null +++ b/internal/auditlog/audit.go @@ -0,0 +1,148 @@ +// Package audit records privileged write operations to an embedded SQLite +// database so there is a durable "who did what" trail. +package auditlog + +import ( + "database/sql" + "fmt" + "log" + "os" + "path/filepath" + "sync" + "time" + + _ "modernc.org/sqlite" +) + +// buffer sizes the channel that decouples Record (request path) from the DB +// writer goroutine. Bursts up to this many entries return immediately; only a +// sustained overflow falls back to a synchronous write. +const buffer = 1024 + +// Entry is one recorded action. +type Entry struct { + Time string `json:"time" example:"2026-06-20T08:15:04Z" doc:"When the action occurred (RFC3339, UTC)"` + Username string `json:"username" example:"alice" doc:"Who performed it"` + Method string `json:"method" example:"POST" doc:"HTTP method"` + Path string `json:"path" example:"/api/users" doc:"Request path"` + Module string `json:"module" example:"users" doc:"Target module"` + Status int `json:"status" example:"200" doc:"HTTP response status"` + + // ts is the capture time (unix seconds), carried through the writer channel + // so the stored timestamp reflects when the action happened, not when the + // background writer got to it. + ts int64 +} + +type Store struct { + db *sql.DB + ch chan Entry + done chan struct{} // closed when the writer has drained and exited + + // mu guards the channel against a Record racing Close. Record takes RLock + // (many concurrent senders); Close takes the write lock to flip closed and + // close the channel exactly once. Without it, a Record after Close would + // panic sending on a closed channel. + mu sync.RWMutex + closed bool +} + +// New opens (creating if needed) the audit database at path and starts the +// background writer that drains recorded entries to disk. +func New(path string) (*Store, error) { + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return nil, fmt.Errorf("audit db dir: %w", err) + } + db, err := sql.Open("sqlite", path) + if err != nil { + return nil, fmt.Errorf("open audit db: %w", err) + } + // ponytail: single connection serializes writes, same rationale as the + // session store; the background writer is the only writer. + db.SetMaxOpenConns(1) + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS audit ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ts INTEGER NOT NULL, + username TEXT NOT NULL, + method TEXT NOT NULL, + path TEXT NOT NULL, + module TEXT NOT NULL, + status INTEGER NOT NULL + )`); err != nil { + return nil, fmt.Errorf("create audit table: %w", err) + } + s := &Store{db: db, ch: make(chan Entry, buffer), done: make(chan struct{})} + go s.writer() + return s, nil +} + +// writer is the single goroutine that persists entries, keeping all DB writes +// off the request path. It exits once Close closes the channel, after draining +// whatever is buffered. +func (s *Store) writer() { + defer close(s.done) + for e := range s.ch { + s.insert(e) + } +} + +// Close stops accepting entries, drains those still buffered, and closes the +// database. Call it during shutdown, after the HTTP server has stopped so no +// further Record calls can race with the channel close. +func (s *Store) Close() error { + s.mu.Lock() + s.closed = true + close(s.ch) + s.mu.Unlock() + <-s.done + return s.db.Close() +} + +func (s *Store) insert(e Entry) { + if _, err := s.db.Exec( + `INSERT INTO audit (ts, username, method, path, module, status) VALUES (?, ?, ?, ?, ?, ?)`, + e.ts, e.Username, e.Method, e.Path, e.Module, e.Status, + ); err != nil { + log.Printf("audit: insert failed (%s %s by %s): %v", e.Method, e.Path, e.Username, err) + } +} + +// Record queues an entry for the background writer and returns immediately. If +// the buffer is full (sustained overload), it writes synchronously instead so an +// audit entry is never silently dropped - only then does it pay DB latency. +func (s *Store) Record(username, method, path, module string, status int) { + e := Entry{ts: time.Now().Unix(), Username: username, Method: method, Path: path, Module: module, Status: status} + s.mu.RLock() + defer s.mu.RUnlock() + if s.closed { + s.insert(e) // post-shutdown straggler: write directly, channel is gone + return + } + select { + case s.ch <- e: + default: + s.insert(e) + } +} + +// List returns the most recent entries, newest first, capped at limit. +func (s *Store) List(limit int) ([]Entry, error) { + rows, err := s.db.Query( + `SELECT ts, username, method, path, module, status FROM audit ORDER BY id DESC LIMIT ?`, limit) + if err != nil { + return nil, err + } + defer rows.Close() + + entries := []Entry{} + for rows.Next() { + var e Entry + var ts int64 + if err := rows.Scan(&ts, &e.Username, &e.Method, &e.Path, &e.Module, &e.Status); err != nil { + return nil, err + } + e.Time = time.Unix(ts, 0).UTC().Format(time.RFC3339) + entries = append(entries, e) + } + return entries, rows.Err() +} diff --git a/internal/auditlog/audit_test.go b/internal/auditlog/audit_test.go new file mode 100644 index 0000000..ce5c0f0 --- /dev/null +++ b/internal/auditlog/audit_test.go @@ -0,0 +1,107 @@ +package auditlog + +import ( + "path/filepath" + "testing" + "time" +) + +func newTestStore(t *testing.T) *Store { + t.Helper() + s, err := New(filepath.Join(t.TempDir(), "audit.db")) + if err != nil { + t.Fatal(err) + } + return s +} + +// eventually polls List until at least want entries are persisted (writes are +// async) or it times out. +func eventually(t *testing.T, s *Store, want int) []Entry { + t.Helper() + deadline := time.Now().Add(time.Second) + for { + got, err := s.List(100) + if err != nil { + t.Fatal(err) + } + if len(got) >= want { + return got + } + if time.Now().After(deadline) { + t.Fatalf("timed out waiting for %d entries, have %d", want, len(got)) + } + time.Sleep(5 * time.Millisecond) + } +} + +func TestRecordAndList(t *testing.T) { + s := newTestStore(t) + s.Record("alice", "POST", "/api/users", "users", 200) + s.Record("bob", "DELETE", "/api/users/x", "users", 403) + + got := eventually(t, s, 2) + // Newest first. + if got[0].Username != "bob" || got[0].Status != 403 { + t.Errorf("newest entry wrong: %+v", got[0]) + } + if got[1].Username != "alice" || got[1].Method != "POST" { + t.Errorf("oldest entry wrong: %+v", got[1]) + } + if got[0].Time == "" { + t.Error("time not populated") + } +} + +func TestListLimit(t *testing.T) { + s := newTestStore(t) + for range 5 { + s.Record("u", "POST", "/api/x", "system", 200) + } + eventually(t, s, 5) + got, err := s.List(3) + if err != nil { + t.Fatal(err) + } + if len(got) != 3 { + t.Errorf("limit not honored: got %d, want 3", len(got)) + } +} + +func TestCloseDrains(t *testing.T) { + path := filepath.Join(t.TempDir(), "audit.db") + s, err := New(path) + if err != nil { + t.Fatal(err) + } + for range 50 { + s.Record("u", "POST", "/api/x", "system", 200) + } + // Close must flush everything still buffered before returning. + if err := s.Close(); err != nil { + t.Fatal(err) + } + + // Reopen and confirm all 50 made it to disk. + s2, err := New(path) + if err != nil { + t.Fatal(err) + } + got, err := s2.List(100) + if err != nil { + t.Fatal(err) + } + if len(got) != 50 { + t.Errorf("Close did not drain: got %d entries, want 50", len(got)) + } +} + +func TestListEmpty(t *testing.T) { + got, err := newTestStore(t).List(10) + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Errorf("fresh store should be empty, got %d", len(got)) + } +} diff --git a/internal/auth/bearer.go b/internal/auth/bearer.go new file mode 100644 index 0000000..ec68e7a --- /dev/null +++ b/internal/auth/bearer.go @@ -0,0 +1,48 @@ +package auth + +import ( + "strings" + "time" +) + +// TokenAuth verifies Bearer credentials against the TokenStore and throttles +// brute force. Unlike the login throttle (keyed on username+IP), a Bearer token +// has no "login" step to rate-limit, so guesses are throttled by source IP +// alone. The window is looser than login's because a legitimate dashboard may +// fire many requests in a minute - only repeated *failures* count. +type TokenAuth struct { + store *TokenStore + throttle *failLimiter +} + +// NewTokenAuth wraps a store with an IP-keyed failure throttle. +func NewTokenAuth(store *TokenStore) *TokenAuth { + return &TokenAuth{store: store, throttle: newFailLimiter(20, time.Minute)} +} + +// Verify resolves a presented Bearer token to its name. throttled is true when +// the source IP is in cooldown after too many bad tokens; the caller should +// answer 429 without consulting the store. +func (a *TokenAuth) Verify(ip, raw string) (name string, ok, throttled bool) { + if a.throttle.blocked(ip) { + return "", false, true + } + name, found := a.store.Lookup(raw) + if !found { + a.throttle.fail(ip) + return "", false, false + } + a.throttle.reset(ip) + return name, true, false +} + +// BearerToken extracts the credential from an Authorization header value, +// reporting whether it was a Bearer scheme. Returns ("", false) for cookie-only +// or unauthenticated requests. +func BearerToken(authHeader string) (string, bool) { + const scheme = "Bearer " + if len(authHeader) <= len(scheme) || !strings.EqualFold(authHeader[:len(scheme)], scheme) { + return "", false + } + return strings.TrimSpace(authHeader[len(scheme):]), true +} diff --git a/internal/auth/login.go b/internal/auth/login.go new file mode 100644 index 0000000..4d624e1 --- /dev/null +++ b/internal/auth/login.go @@ -0,0 +1,90 @@ +package auth + +import ( + "context" + "net/http" + "time" + + "nadir/internal/auditlog" + + "github.com/danielgtaylor/huma/v2" +) + +// authenticator verifies a username/password (PAM in production). It's a field +// of the login handler rather than a package global so tests can inject a stub +// without mutating shared state. +type authenticator func(username, password string) error + +type LoginInput struct { + Body struct { + Username string `json:"username" doc:"System username"` + Password string `json:"password" doc:"System password"` + } +} + +type LoginOutput struct { + SetCookie http.Cookie `header:"Set-Cookie"` + Body struct { + Status string `json:"status" example:"logged in"` + } +} + +// RegisterLogin wires the login operation into the Huma API. It has no +// permission metadata, so the RBAC middleware lets it through unauthenticated. +// +// secure sets the Secure attribute on the session cookie. Keep it true in +// production (the cookie is then only sent over HTTPS); set it false for local +// development over plain HTTP, where a Secure cookie would never be sent back. +func RegisterLogin(api huma.API, sessions *SessionStore, auditor *auditlog.Store, secure bool) { + // loginThrottle blunts brute force: 5 failures for a username+source IP + // trigger a one-minute cooldown. See throttle.go for the ceiling/upgrade path. + registerLogin(api, sessions, auditor, secure, Authenticate, newFailLimiter(5, time.Minute)) +} + +func registerLogin(api huma.API, sessions *SessionStore, auditor *auditlog.Store, secure bool, authenticate authenticator, throttle *failLimiter) { + huma.Register(api, huma.Operation{ + OperationID: "login", + Method: "POST", + Path: "/api/login", + Summary: "Authenticate and start a session", + Description: "Verifies the username and password against PAM (the " + + "dedicated nadir service) and, on success, sets an HttpOnly " + + "session cookie used to authorize all other endpoints.", + Tags: []string{"Authentication"}, + Errors: []int{401, 429}, + }, func(ctx context.Context, in *LoginInput) (*LoginOutput, error) { + // Throttle brute force: too many recent failures for this account/source + // put it in a short cooldown before the password is even checked. + throttleKey := in.Body.Username + "|" + ClientIP(ctx) + if throttle.blocked(throttleKey) { + return nil, huma.Error429TooManyRequests("too many failed login attempts; wait a minute") + } + // Record both outcomes: failed logins are the brute-force signal, and the + // username is captured even on failure (which account is being targeted). + if err := authenticate(in.Body.Username, in.Body.Password); err != nil { + throttle.fail(throttleKey) + auditor.Record(in.Body.Username, "POST", "/api/login", "auth", http.StatusUnauthorized) + return nil, huma.Error401Unauthorized("invalid credentials") + } + throttle.reset(throttleKey) + auditor.Record(in.Body.Username, "POST", "/api/login", "auth", http.StatusOK) + + sessionID, err := sessions.Create(in.Body.Username) + if err != nil { + return nil, huma.Error500InternalServerError("could not create session", err) + } + out := &LoginOutput{ + SetCookie: http.Cookie{ + Name: "nadir_session_id", + Value: sessionID, + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteStrictMode, + Expires: time.Now().Add(24 * time.Hour), + }, + } + out.Body.Status = "logged in" + return out, nil + }) +} diff --git a/internal/auth/login_test.go b/internal/auth/login_test.go new file mode 100644 index 0000000..72c3358 --- /dev/null +++ b/internal/auth/login_test.go @@ -0,0 +1,185 @@ +package auth + +import ( + "errors" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "nadir/internal/auditlog" + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestEnsurePAMService(t *testing.T) { + tempFile := filepath.Join(t.TempDir(), "nadir-pam-test") + oldPath := pamServicePath + pamServicePath = tempFile + defer func() { pamServicePath = oldPath }() + + if err := EnsurePAMService(); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(tempFile) + if err != nil { + t.Fatal(err) + } + if string(data) != pamServiceContent { + t.Errorf("got content %q, want %q", string(data), pamServiceContent) + } + + customContent := "custom pam content" + if err := os.WriteFile(tempFile, []byte(customContent), 0644); err != nil { + t.Fatal(err) + } + + if err := EnsurePAMService(); err != nil { + t.Fatal(err) + } + + data, err = os.ReadFile(tempFile) + if err != nil { + t.Fatal(err) + } + if string(data) != customContent { + t.Errorf("EnsurePAMService clobbered file: got %q, want %q", string(data), customContent) + } +} + +func TestLoginLogoutThrottling(t *testing.T) { + tempDir := t.TempDir() + auditStore, err := auditlog.New(filepath.Join(tempDir, "audit.db")) + if err != nil { + t.Fatal(err) + } + defer auditStore.Close() + + sessions, err := NewSessionStore(filepath.Join(tempDir, "sessions.db")) + if err != nil { + t.Fatal(err) + } + + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + + // Inject a stub authenticator and a short-window throttle (3 failures) at + // registration — no package globals to mutate. The throttle keys on + // username+IP, so the success/failure cases below and the throttle case (a + // distinct user) don't interfere. + authMock := func(username, password string) error { + if password == "correct" { + return nil + } + return errors.New("pam error") + } + registerLogin(api, sessions, auditStore, false, authMock, newFailLimiter(3, 500*time.Millisecond)) + RegisterLogout(api, sessions, false) + + // 1. Test failed login + resp := api.Post("/api/login", struct { + Username string `json:"username"` + Password string `json:"password"` + }{ + Username: "admin", + Password: "wrong", + }) + if resp.Code != http.StatusUnauthorized { + t.Errorf("failed login: got code %d, want %d", resp.Code, http.StatusUnauthorized) + } + + // 2. Test successful login + resp = api.Post("/api/login", struct { + Username string `json:"username"` + Password string `json:"password"` + }{ + Username: "admin", + Password: "correct", + }) + if resp.Code != http.StatusOK { + t.Errorf("successful login: got code %d, want %d", resp.Code, http.StatusOK) + } + + cookieHeader := resp.Header().Get("Set-Cookie") + if !strings.Contains(cookieHeader, "nadir_session_id=") { + t.Fatalf("Set-Cookie header missing nadir_session_id: %q", cookieHeader) + } + + var sessionID string + parts := strings.SplitSeq(cookieHeader, ";") + for part := range parts { + part = strings.TrimSpace(part) + if after, ok := strings.CutPrefix(part, "nadir_session_id="); ok { + sessionID = after + break + } + } + + if sessionID == "" { + t.Fatal("nadir_session_id cookie not found") + } + + _, ok := sessions.GetByToken(sessionID) + if !ok { + t.Fatal("session not found in session store") + } + + // 3. Test logout + resp = api.Post("/api/logout", "Cookie: nadir_session_id="+sessionID, struct{}{}) + if resp.Code != http.StatusOK { + t.Errorf("logout failed: got code %d, want %d", resp.Code, http.StatusOK) + } + + _, ok = sessions.GetByToken(sessionID) + if ok { + t.Fatal("session still valid after logout") + } + + // 4. Test throttling (the handler was registered with a 3-failure limiter). + for range 3 { + api.Post("/api/login", struct { + Username string `json:"username"` + Password string `json:"password"` + }{ + Username: "throttled-user", + Password: "wrong", + }) + } + + resp = api.Post("/api/login", struct { + Username string `json:"username"` + Password string `json:"password"` + }{ + Username: "throttled-user", + Password: "correct", + }) + if resp.Code != http.StatusTooManyRequests { + t.Errorf("throttled login: got code %d, want %d", resp.Code, http.StatusTooManyRequests) + } + + time.Sleep(600 * time.Millisecond) + + resp = api.Post("/api/login", struct { + Username string `json:"username"` + Password string `json:"password"` + }{ + Username: "throttled-user", + Password: "correct", + }) + if resp.Code != http.StatusOK { + t.Errorf("login after cooldown: got code %d, want %d", resp.Code, http.StatusOK) + } +} diff --git a/internal/auth/logout.go b/internal/auth/logout.go new file mode 100644 index 0000000..9421914 --- /dev/null +++ b/internal/auth/logout.go @@ -0,0 +1,53 @@ +package auth + +import ( + "context" + "net/http" + + "github.com/danielgtaylor/huma/v2" +) + +type LogoutInput struct { + SessionID string `cookie:"nadir_session_id"` +} + +type LogoutOutput struct { + SetCookie http.Cookie `header:"Set-Cookie"` + Body struct { + Status string `json:"status" example:"logged out"` + } +} + +// RegisterLogout wires the logout operation. Like login it carries no permission +// metadata, so the RBAC middleware lets it through: it deletes whatever session +// the cookie names (a no-op for an unknown/expired token) and clears the cookie. +func RegisterLogout(api huma.API, sessions *SessionStore, secure bool) { + huma.Register(api, huma.Operation{ + OperationID: "logout", + Method: "POST", + Path: "/api/logout", + Summary: "End the current session", + Description: "Invalidates the session named by the cookie and clears it. " + + "Always succeeds, even without a valid session.", + Tags: []string{"Authentication"}, + }, func(ctx context.Context, in *LogoutInput) (*LogoutOutput, error) { + if in.SessionID != "" { + if err := sessions.Delete(in.SessionID); err != nil { + return nil, huma.Error500InternalServerError("could not end session", err) + } + } + out := &LogoutOutput{ + SetCookie: http.Cookie{ + Name: "nadir_session_id", + Value: "", + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteStrictMode, + MaxAge: -1, // delete the cookie now + }, + } + out.Body.Status = "logged out" + return out, nil + }) +} diff --git a/internal/auth/pam.go b/internal/auth/pam.go new file mode 100644 index 0000000..117edca --- /dev/null +++ b/internal/auth/pam.go @@ -0,0 +1,18 @@ +package auth + +import "github.com/msteinert/pam" + +func Authenticate(username, password string) error { + t, err := pam.StartFunc(PAMService, username, func(s pam.Style, msg string) (string, error) { + if s == pam.PromptEchoOff { + return password, nil + } + return "", nil + }) + + if err != nil { + return err + } + + return t.Authenticate(0) +} diff --git a/internal/auth/pamservice.go b/internal/auth/pamservice.go new file mode 100644 index 0000000..a4a287d --- /dev/null +++ b/internal/auth/pamservice.go @@ -0,0 +1,51 @@ +package auth + +import ( + "fmt" + "os" +) + +// PAMService is the name of the PAM service nadir authenticates against. +// It maps to /etc/pam.d/. We use a dedicated service rather than +// a stock one like "login" so we control exactly which modules run during +// authentication. See README.md ("PAM service") for the full rationale. +const PAMService = "nadir" + +var pamServicePath = "/etc/pam.d/" + PAMService + +// pamServiceContent is the minimal stack nadir needs: verify the password +// against /etc/shadow and confirm the account is valid. It deliberately omits +// pam_fprintd (blocks ~30s waiting for a fingerprint swipe that never comes), +// pam_systemd, pam_env, and the rest of the distro's login stack. +const pamServiceContent = `#%PAM-1.0 +# Managed by nadir. Do not rely on hand edits surviving - nadir recreates this +# file on startup only if it is missing. Minimal auth stack: verify the +# password against /etc/shadow and confirm the account is valid. Deliberately +# omits pam_fprintd (blocks ~30s on a fingerprint swipe), pam_systemd, pam_env. +auth required pam_unix.so +account required pam_unix.so +` + +// EnsurePAMService writes the PAM service file if it is missing. nadir already +// runs as root (pam_unix needs to read /etc/shadow), so it can install its own +// PAM config rather than relying on a separate install step. +// +// It will not overwrite an existing file: an admin who has customized +// /etc/pam.d/nadir keeps their version. Returns an error only on a real I/O +// problem so main can fail loudly instead of later looking like bad +// credentials (a missing file falls through to pam_deny via /etc/pam.d/other). +func EnsurePAMService() error { + switch _, err := os.Stat(pamServicePath); { + case err == nil: + return nil // already present - leave admin customizations intact + case os.IsNotExist(err): + // fall through and create it + default: + return fmt.Errorf("stat %s: %w", pamServicePath, err) + } + + if err := os.WriteFile(pamServicePath, []byte(pamServiceContent), 0644); err != nil { + return fmt.Errorf("write %s: %w (need root, and /etc/pam.d must be writable)", pamServicePath, err) + } + return nil +} diff --git a/internal/auth/session.go b/internal/auth/session.go new file mode 100644 index 0000000..522b033 --- /dev/null +++ b/internal/auth/session.go @@ -0,0 +1,99 @@ +package auth + +import ( + "crypto/rand" + "database/sql" + "encoding/hex" + "fmt" + "os" + "path/filepath" + "time" + + _ "modernc.org/sqlite" +) + +const sessionTTL = 24 * time.Hour + +type Session struct { + Username string +} + +// SessionStore persists sessions in an embedded SQLite database so they survive +// process restarts. Expired rows are dropped lazily on read. +type SessionStore struct { + db *sql.DB +} + +// NewSessionStore opens (creating if needed) the SQLite database at path and +// ensures the sessions table exists. The parent directory is created too. +func NewSessionStore(path string) (*SessionStore, error) { + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return nil, fmt.Errorf("session db dir: %w", err) + } + db, err := sql.Open("sqlite", path) + if err != nil { + return nil, fmt.Errorf("open session db: %w", err) + } + // ponytail: single connection serializes access, avoiding SQLite's + // "database is locked" under concurrent writes. A sysadmin panel's session + // rate is trivial; switch to WAL + a real pool only if that ever bites. + db.SetMaxOpenConns(1) + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS sessions ( + token TEXT PRIMARY KEY, + username TEXT NOT NULL, + expires_at INTEGER NOT NULL + )`); err != nil { + return nil, fmt.Errorf("create sessions table: %w", err) + } + return &SessionStore{db: db}, nil +} + +// Ping reports whether the session database is reachable. Used by the health +// check. +func (s *SessionStore) Ping() error { return s.db.Ping() } + +func (s *SessionStore) Create(username string) (string, error) { + token := randomToken() + expires := time.Now().Add(sessionTTL) + if _, err := s.db.Exec( + `INSERT INTO sessions (token, username, expires_at) VALUES (?, ?, ?)`, + token, username, expires.Unix(), + ); err != nil { + return "", err + } + return token, nil +} + +// Delete removes a session, invalidating it immediately (logout). Deleting an +// unknown token is a no-op. +func (s *SessionStore) Delete(token string) error { + _, err := s.db.Exec(`DELETE FROM sessions WHERE token = ?`, token) + return err +} + +func (s *SessionStore) GetByToken(token string) (Session, bool) { + var username string + var expires int64 + err := s.db.QueryRow( + `SELECT username, expires_at FROM sessions WHERE token = ?`, token, + ).Scan(&username, &expires) + if err != nil { + return Session{}, false + } + if time.Now().Unix() > expires { + s.db.Exec(`DELETE FROM sessions WHERE token = ?`, token) + return Session{}, false + } + return Session{Username: username}, true +} + +func randomToken() string { + b := make([]byte, 32) + // crypto/rand.Read never returns an error on supported platforms; if it + // somehow does, an all-zero (guessable) token would be a security hole, so + // fail hard rather than hand one out. + if _, err := rand.Read(b); err != nil { + panic("crypto/rand failed: " + err.Error()) + } + return hex.EncodeToString(b) +} diff --git a/internal/auth/session_test.go b/internal/auth/session_test.go new file mode 100644 index 0000000..1e91730 --- /dev/null +++ b/internal/auth/session_test.go @@ -0,0 +1,88 @@ +package auth + +import ( + "path/filepath" + "testing" + "time" +) + +func TestSessionPersistsAcrossReopen(t *testing.T) { + path := filepath.Join(t.TempDir(), "sessions.db") + + store, err := NewSessionStore(path) + if err != nil { + t.Fatal(err) + } + token, err := store.Create("urania") + if err != nil { + t.Fatal(err) + } + + // Reopen the same file: a fresh process must still see the session. + reopened, err := NewSessionStore(path) + if err != nil { + t.Fatal(err) + } + sess, ok := reopened.GetByToken(token) + if !ok || sess.Username != "urania" { + t.Fatalf("session lost after reopen: got %+v ok=%v", sess, ok) + } +} + +func TestExpiredSessionRejected(t *testing.T) { + store, err := NewSessionStore(filepath.Join(t.TempDir(), "sessions.db")) + if err != nil { + t.Fatal(err) + } + // Write a row that expired an hour ago, bypassing Create's TTL. + _, err = store.db.Exec( + `INSERT INTO sessions (token, username, expires_at) VALUES (?, ?, ?)`, + "stale", "urania", time.Now().Add(-time.Hour).Unix(), + ) + if err != nil { + t.Fatal(err) + } + if _, ok := store.GetByToken("stale"); ok { + t.Fatal("expired session was accepted") + } + // Lazy cleanup should have deleted the row. + var n int + store.db.QueryRow(`SELECT count(*) FROM sessions WHERE token = ?`, "stale").Scan(&n) + if n != 0 { + t.Fatalf("expired row not cleaned up: %d rows remain", n) + } +} + +func TestDeleteInvalidatesSession(t *testing.T) { + store, err := NewSessionStore(filepath.Join(t.TempDir(), "sessions.db")) + if err != nil { + t.Fatal(err) + } + token, err := store.Create("urania") + if err != nil { + t.Fatal(err) + } + if _, ok := store.GetByToken(token); !ok { + t.Fatal("session should exist before logout") + } + if err := store.Delete(token); err != nil { + t.Fatal(err) + } + if _, ok := store.GetByToken(token); ok { + t.Fatal("session still valid after logout") + } + // Deleting an unknown/already-deleted token is a no-op, not an error. + if err := store.Delete(token); err != nil { + t.Errorf("deleting unknown token should be a no-op, got %v", err) + } +} + +func TestUnknownTokenRejected(t *testing.T) { + store, err := NewSessionStore(filepath.Join(t.TempDir(), "sessions.db")) + if err != nil { + t.Fatal(err) + } + if _, ok := store.GetByToken("nope"); ok { + t.Fatal("unknown token was accepted") + } +} diff --git a/internal/auth/throttle.go b/internal/auth/throttle.go new file mode 100644 index 0000000..ac084a8 --- /dev/null +++ b/internal/auth/throttle.go @@ -0,0 +1,110 @@ +package auth + +import ( + "context" + "net" + "net/http" + "strings" + "sync" + "time" +) + +// failLimiter throttles repeated login failures keyed by "username|ip". After +// max failures it imposes a cooldown of window before any further attempt for +// that key is accepted, blunting brute force against PAM/shadow. +// +// ponytail: in-memory, single-process, fixed cooldown - correct for a single +// app instance. State is lost on restart, but only an operator restarts the +// process (an attacker can't), so persisting it would buy nothing. Source +// spoofing is handled by the network (VPN-only + trusted proxy sets XFF), not +// here. Reach for pam_faillock only if a failed web login should also lock the +// OS account against ssh/console - a different layer we deliberately don't span. +type failLimiter struct { + mu sync.Mutex + attempts map[string]*attemptState + max int + window time.Duration +} + +type attemptState struct { + count int + until time.Time +} + +// maxTrackedKeys bounds memory: an attacker rotating username/IP can't grow the +// map without limit. When exceeded we drop all throttle state - a crude reset +// that briefly forgets cooldowns, acceptable for a single-node panel. +const maxTrackedKeys = 10000 + +func newFailLimiter(max int, window time.Duration) *failLimiter { + return &failLimiter{attempts: map[string]*attemptState{}, max: max, window: window} +} + +// blocked reports whether the key is currently in cooldown. +func (l *failLimiter) blocked(key string) bool { + l.mu.Lock() + defer l.mu.Unlock() + s := l.attempts[key] + return s != nil && time.Now().Before(s.until) +} + +// fail records a failed attempt and starts a cooldown once max is reached. +func (l *failLimiter) fail(key string) { + l.mu.Lock() + defer l.mu.Unlock() + if len(l.attempts) > maxTrackedKeys { + l.attempts = map[string]*attemptState{} + } + s := l.attempts[key] + if s == nil { + s = &attemptState{} + l.attempts[key] = s + } + s.count++ + if s.count >= l.max { + s.until = time.Now().Add(l.window) + s.count = 0 // restart the window after the cooldown is set + } +} + +// reset clears state for a key after a successful login. +func (l *failLimiter) reset(key string) { + l.mu.Lock() + defer l.mu.Unlock() + delete(l.attempts, key) +} + +type ctxKey int + +const clientIPKey ctxKey = 0 + +// WithClientIP wraps a handler so the client's IP is available downstream via +// ClientIP, so the login throttle can key on source IP without coupling to a +// specific HTTP adapter. +// +// When trustProxy is set, the IP is taken from the first X-Forwarded-For hop +// (the original client behind the reverse proxy). X-Forwarded-For is fully +// caller-controlled, so this is only honored when an admin has opted in by +// putting nadir behind a proxy - and nadir must then be reachable only by that +// proxy. When trustProxy is false the header is ignored and RemoteAddr wins. +func WithClientIP(trustProxy bool, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + if trustProxy { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + first, _, _ := strings.Cut(xff, ",") + ip = strings.TrimSpace(first) + } + } + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), clientIPKey, ip))) + }) +} + +// ClientIP returns the IP recorded by WithClientIP, or "" if absent. +func ClientIP(ctx context.Context) string { + ip, _ := ctx.Value(clientIPKey).(string) + return ip +} diff --git a/internal/auth/tokens.go b/internal/auth/tokens.go new file mode 100644 index 0000000..af748f0 --- /dev/null +++ b/internal/auth/tokens.go @@ -0,0 +1,131 @@ +package auth + +import ( + "crypto/sha256" + "database/sql" + "encoding/hex" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + _ "modernc.org/sqlite" +) + +// tokenPrefix marks a Nadir machine credential, like GitHub's "ghp_". It's +// cosmetic (helps secret scanners and humans recognize a leaked token), not a +// security boundary - the prefix is hashed and stored along with the rest. +const tokenPrefix = "nad_" + +// TokenStore persists machine credentials for non-interactive callers (a +// central dashboard managing N nodes), so they authenticate with a static +// Bearer token instead of a per-host PAM session. Only the SHA-256 of each +// token is stored - a leaked DB or backup can't hand out live credentials. +// +// It lives in its own SQLite file because both the server (read, on every +// Bearer request) and the `nadir token` CLI (write, when minting/revoking) +// touch it. WAL + a busy timeout let those two processes share the file without +// "database is locked". +type TokenStore struct { + db *sql.DB +} + +// TokenInfo is one stored credential's public metadata (never the secret). +type TokenInfo struct { + Name string + Created time.Time +} + +// NewTokenStore opens (creating if needed) the token database at path. +func NewTokenStore(path string) (*TokenStore, error) { + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return nil, fmt.Errorf("token db dir: %w", err) + } + // WAL + busy_timeout: the server process and the CLI process both open this + // file, so a plain rollback journal would surface transient lock errors. + dsn := path + "?_pragma=busy_timeout(5000)&_pragma=journal_mode(WAL)" + db, err := sql.Open("sqlite", dsn) + if err != nil { + return nil, fmt.Errorf("open token db: %w", err) + } + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS tokens ( + name TEXT PRIMARY KEY, + token_hash TEXT NOT NULL UNIQUE, + created_at INTEGER NOT NULL + )`); err != nil { + return nil, fmt.Errorf("create tokens table: %w", err) + } + return &TokenStore{db: db}, nil +} + +func (s *TokenStore) Close() error { return s.db.Close() } + +// Create mints a new token named name and returns the raw secret. The secret is +// shown only here, at generation time - only its hash is stored. A duplicate +// name is rejected (the caller should revoke and re-mint to rotate). +func (s *TokenStore) Create(name string) (string, error) { + if name == "" { + return "", fmt.Errorf("token name required") + } + raw := tokenPrefix + randomToken() + if _, err := s.db.Exec( + `INSERT INTO tokens (name, token_hash, created_at) VALUES (?, ?, ?)`, + name, hashToken(raw), time.Now().Unix(), + ); err != nil { + return "", fmt.Errorf("store token %q (already exists?): %w", name, err) + } + return raw, nil +} + +// Lookup returns the token name for a presented raw secret. +// +// No constant-time compare is needed: we look the secret up by its SHA-256, so +// the value compared in the index is already a hash of attacker-controlled +// input. A timing side-channel could at most leak a stored hash, which is +// useless without a SHA-256 preimage (the actual token). +func (s *TokenStore) Lookup(raw string) (string, bool) { + if !strings.HasPrefix(raw, tokenPrefix) { + return "", false + } + var name string + err := s.db.QueryRow( + `SELECT name FROM tokens WHERE token_hash = ?`, hashToken(raw), + ).Scan(&name) + if err != nil { + return "", false + } + return name, true +} + +// Delete revokes a token by name. Revoking an unknown name is a no-op. The +// change is effective immediately - the server reads this DB live, no restart. +func (s *TokenStore) Delete(name string) error { + _, err := s.db.Exec(`DELETE FROM tokens WHERE name = ?`, name) + return err +} + +// List returns all token names with their creation time, newest first. +func (s *TokenStore) List() ([]TokenInfo, error) { + rows, err := s.db.Query(`SELECT name, created_at FROM tokens ORDER BY created_at DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + infos := []TokenInfo{} + for rows.Next() { + var t TokenInfo + var created int64 + if err := rows.Scan(&t.Name, &created); err != nil { + return nil, err + } + t.Created = time.Unix(created, 0).UTC() + infos = append(infos, t) + } + return infos, rows.Err() +} + +func hashToken(raw string) string { + sum := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/auth/tokens_test.go b/internal/auth/tokens_test.go new file mode 100644 index 0000000..cd8c225 --- /dev/null +++ b/internal/auth/tokens_test.go @@ -0,0 +1,45 @@ +package auth + +import ( + "path/filepath" + "testing" +) + +func TestTokenStore(t *testing.T) { + store, err := NewTokenStore(filepath.Join(t.TempDir(), "tokens.db")) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + raw, err := store.Create("dash") + if err != nil { + t.Fatal(err) + } + if len(raw) < len(tokenPrefix)+32 || raw[:len(tokenPrefix)] != tokenPrefix { + t.Fatalf("token %q lacks %q prefix or is too short", raw, tokenPrefix) + } + + // Round-trip: the minted secret resolves to its name. + if name, ok := store.Lookup(raw); !ok || name != "dash" { + t.Errorf("Lookup(valid) = %q,%v; want dash,true", name, ok) + } + // A wrong secret (and a non-prefixed one) must not resolve. + if _, ok := store.Lookup(tokenPrefix + "wrong"); ok { + t.Error("Lookup(wrong) succeeded") + } + if _, ok := store.Lookup("no-prefix"); ok { + t.Error("Lookup(no prefix) succeeded") + } + // Duplicate name is rejected. + if _, err := store.Create("dash"); err == nil { + t.Error("Create duplicate name succeeded; want error") + } + // Revocation is immediate. + if err := store.Delete("dash"); err != nil { + t.Fatal(err) + } + if _, ok := store.Lookup(raw); ok { + t.Error("Lookup after Delete succeeded") + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..b6e3996 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,173 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" + + "nadir/internal/module" + "nadir/internal/rbac" +) + +// File is the on-disk YAML shape. Module keys and permission values may be +// "*" (must be quoted in YAML to avoid alias syntax). +// +// Example: +// +// server: +// secure_tls: false +// roles: +// admin: +// "*": ["*"] +// auditor: +// "*": [read] +// assignments: +// urania: [admin] +// +// LogFiles is the allowlist of log files readable per unit via the file log +// source. The path query parameter is matched against this set, so an admin +// (not the caller) decides which files are exposable. Example: +// +// log_files: +// nginx.service: +// - /var/log/nginx/access.log +// - /var/log/nginx/error.log +type File struct { + Server Server `yaml:"server"` + Roles map[string]map[string][]string `yaml:"roles"` + Assignments map[string][]string `yaml:"assignments"` + LogFiles map[string][]string `yaml:"log_files"` +} + +// Server holds process-level settings. +type Server struct { + // SecureTLS controls the Secure attribute on the session cookie. A pointer + // so an omitted key means "unset" (defaults to true / production-safe) + // rather than the zero value false. Set false for local HTTP development. + SecureTLS *bool `yaml:"secure_tls"` + Hostname string `yaml:"hostname"` + Port string `yaml:"port"` + // TLS is provided one of three ways, in priority order: + // 1. TrustProxy true - a reverse proxy (e.g. https://example.com) does + // TLS; nadir serves plaintext HTTP and trusts X-Forwarded-For. nadir + // must then be reachable ONLY by the proxy (bind localhost). + // 2. TLSCert+TLSKey - nadir terminates TLS with this PEM pair. + // 3. neither - nadir self-signs in memory (dev only). + // Keep secure_tls true in modes 1 and 2 so the session cookie stays Secure. + TrustProxy bool `yaml:"trust_proxy"` + TLSCert string `yaml:"tls_cert"` + TLSKey string `yaml:"tls_key"` +} + +// SecureCookie reports whether the session cookie should carry the Secure +// attribute, defaulting to true when server.secure_tls is omitted. +func (f *File) SecureCookie() bool { + if f.Server.SecureTLS == nil { + return true + } + return *f.Server.SecureTLS +} + +// DefaultPath returns the default configuration path at +// ~/.config/nadir/config.yaml (honoring XDG_CONFIG_HOME via os.UserConfigDir). +func DefaultPath() (string, error) { + dir, err := os.UserConfigDir() + if err != nil { + return "", fmt.Errorf("detect config dir: %w", err) + } + return filepath.Join(dir, "nadir", "config.yaml"), nil +} + +// ExpandPath expands leading ~ to the current user's home directory. +func ExpandPath(path string) (string, error) { + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("detect home dir: %w", err) + } + return filepath.Join(home, path[2:]), nil + } + return path, nil +} + +// Load reads and parses the YAML file at path. +func Load(path string) (*File, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config %s: %w", path, err) + } + var f File + if err := yaml.Unmarshal(data, &f); err != nil { + return nil, fmt.Errorf("parse config %s: %w", path, err) + } + return &f, nil +} + +// Apply validates the config against the loaded modules and installs roles +// + assignments into the RBAC store. Any reference to an unknown module or +// unknown permission causes a startup error. +func Apply(f *File, roles *rbac.RBAC, mods []module.Module) error { + // Build a lookup: module ID -> the permissions it exposes. + knownPerms := map[string]map[rbac.Permission]bool{} + for _, m := range mods { + set := map[rbac.Permission]bool{} + for _, p := range m.Permissions() { + set[p] = true + } + knownPerms[m.ID()] = set + } + + // 1. Define roles, validating every named module + permission. + for roleName, grants := range f.Roles { + converted := map[string][]rbac.Permission{} + for modKey, perms := range grants { + if modKey != rbac.Wildcard { + if _, ok := knownPerms[modKey]; !ok { + return fmt.Errorf("role %q references unknown module %q", roleName, modKey) + } + } + permList := make([]rbac.Permission, 0, len(perms)) + for _, raw := range perms { + p := rbac.Permission(raw) + if p == rbac.All { + permList = append(permList, p) + continue + } + if !permissionExists(p, modKey, knownPerms) { + return fmt.Errorf("role %q grants %q on module %q, but that permission is not exported by any matching module", roleName, raw, modKey) + } + permList = append(permList, p) + } + converted[modKey] = permList + } + roles.DefineRole(rbac.Role{Name: roleName, ModuleGrants: converted}) + } + + // 2. Apply assignments, validating role names. + for user, assigned := range f.Assignments { + for _, roleName := range assigned { + if !roles.RoleExists(roleName) { + return fmt.Errorf("user %q assigned to unknown role %q", user, roleName) + } + roles.AssignRole(user, roleName) + } + } + return nil +} + +// permissionExists returns true if perm is exported by the named module, +// or by any module when modKey is the wildcard. +func permissionExists(perm rbac.Permission, modKey string, known map[string]map[rbac.Permission]bool) bool { + if modKey == rbac.Wildcard { + for _, perms := range known { + if perms[perm] { + return true + } + } + return false + } + return known[modKey][perm] +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..4a75cf5 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,152 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "nadir/internal/module" + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +// fakeModule implements module.Module with a fixed permission set, so config +// tests don't depend on the concrete modules' exec behavior. +type fakeModule struct { + id string + perms []rbac.Permission +} + +func (f fakeModule) ID() string { return f.id } +func (f fakeModule) Name() string { return f.id } +func (f fakeModule) Permissions() []rbac.Permission { return f.perms } +func (f fakeModule) Register(huma.API) {} + +func mods() []module.Module { + return []module.Module{ + fakeModule{id: "system", perms: []rbac.Permission{rbac.Read, rbac.Write, rbac.Root}}, + fakeModule{id: "services", perms: []rbac.Permission{rbac.Read, rbac.Write}}, + } +} + +func TestSecureCookieDefaultsTrue(t *testing.T) { + if !(&File{}).SecureCookie() { + t.Error("omitted secure_tls should default to true") + } + no := false + if (&File{Server: Server{SecureTLS: &no}}).SecureCookie() { + t.Error("secure_tls: false should disable the Secure flag") + } +} + +func TestLoad(t *testing.T) { + path := filepath.Join(t.TempDir(), "config.yaml") + os.WriteFile(path, []byte(` +server: + secure_tls: false +roles: + admin: + "*": ["*"] +assignments: + urania: [admin] +`), 0600) + + f, err := Load(path) + if err != nil { + t.Fatal(err) + } + if f.SecureCookie() { + t.Error("secure_tls: false not parsed") + } + if len(f.Roles["admin"]) == 0 || len(f.Assignments["urania"]) == 0 { + t.Error("roles/assignments not parsed") + } +} + +func TestLoadMissingFile(t *testing.T) { + if _, err := Load(filepath.Join(t.TempDir(), "nope.yaml")); err == nil { + t.Fatal("expected error for missing file") + } +} + +func TestApplyValid(t *testing.T) { + f := &File{ + Roles: map[string]map[string][]string{ + "admin": {"*": {"*"}}, + "auditor": {"*": {"read"}}, + "sysop": {"system": {"read", "root"}}, + }, + Assignments: map[string][]string{"urania": {"admin"}, "bob": {"sysop"}}, + } + roles := rbac.New() + if err := Apply(f, roles, mods()); err != nil { + t.Fatal(err) + } + if !roles.Can("urania", "services", rbac.Write) { + t.Error("admin wildcard not applied") + } + if !roles.Can("bob", "system", rbac.Root) || roles.Can("bob", "system", rbac.Write) { + t.Error("sysop grants not applied correctly") + } +} + +func TestApplyErrors(t *testing.T) { + tests := []struct { + name string + f *File + }{ + {"unknown module", &File{Roles: map[string]map[string][]string{ + "r": {"firewall": {"read"}}}}}, + {"unknown perm on known module", &File{Roles: map[string]map[string][]string{ + "r": {"system": {"banana"}}}}}, + {"wildcard module with perm no module exports", &File{Roles: map[string]map[string][]string{ + "r": {"*": {"banana"}}}}}, + {"assignment to unknown role", &File{ + Roles: map[string]map[string][]string{"r": {"system": {"read"}}}, + Assignments: map[string][]string{"u": {"ghost"}}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Apply(tt.f, rbac.New(), mods()); err == nil { + t.Errorf("expected error for %s", tt.name) + } + }) + } +} + +func TestApplyWildcardPermAlwaysOK(t *testing.T) { + f := &File{Roles: map[string]map[string][]string{"r": {"system": {"*"}}}} + if err := Apply(f, rbac.New(), mods()); err != nil { + t.Fatalf("wildcard permission should validate: %v", err) + } +} + +func TestDefaultPathAndExpandPath(t *testing.T) { + defaultPath, err := DefaultPath() + if err != nil { + t.Fatalf("DefaultPath failed: %v", err) + } + expectedSuffix := filepath.Join("nadir", "config.yaml") + if !filepath.IsAbs(defaultPath) || !strings.HasSuffix(defaultPath, expectedSuffix) { + t.Errorf("expected default path to end with %q and be absolute, got %q", expectedSuffix, defaultPath) + } + + expanded, err := ExpandPath("~/foo/config.yaml") + if err != nil { + t.Fatalf("ExpandPath failed: %v", err) + } + if !filepath.IsAbs(expanded) || !strings.HasSuffix(expanded, filepath.Join("foo", "config.yaml")) { + t.Errorf("ExpandPath did not resolve ~/ correctly, got %q", expanded) + } + + plain := "/etc/nadir/config.yaml" + expandedPlain, err := ExpandPath(plain) + if err != nil { + t.Fatalf("ExpandPath failed for plain path: %v", err) + } + if expandedPlain != plain { + t.Errorf("expected no-op for %q, got %q", plain, expandedPlain) + } +} diff --git a/internal/meta/health.go b/internal/meta/health.go new file mode 100644 index 0000000..61d0aae --- /dev/null +++ b/internal/meta/health.go @@ -0,0 +1,44 @@ +package meta + +import ( + "context" + + "nadir" + "nadir/internal/auth" + + "github.com/danielgtaylor/huma/v2" +) + +type HealthOutput struct { + Body struct { + Status string `json:"status" example:"ok" doc:"Overall health"` + Database string `json:"database" example:"ok" doc:"Embedded SQLite session store state"` + Version string `json:"version" example:"1.0.0" doc:"Application version"` + } +} + +// RegisterHealth adds a public liveness/readiness probe. It is intentionally +// unauthenticated (no permission metadata) so load balancers and orchestrators +// can reach it. Returns 503 when the SQLite session store is unreachable, so +// probes can key off the status code without parsing the body. +func RegisterHealth(api huma.API, sessions *auth.SessionStore) { + huma.Register(api, huma.Operation{ + OperationID: "health", + Method: "GET", + Path: "/api/health", + Summary: "Health check", + Description: "Public liveness/readiness probe. Reports whether the embedded " + + "SQLite session store is reachable. Returns 503 when it is not.", + Tags: []string{"Meta"}, + Errors: []int{503}, + }, func(ctx context.Context, _ *struct{}) (*HealthOutput, error) { + if err := sessions.Ping(); err != nil { + return nil, huma.Error503ServiceUnavailable("session database unreachable", err) + } + out := &HealthOutput{} + out.Body.Status = "ok" + out.Body.Database = "ok" + out.Body.Version = nadir.Version + return out, nil + }) +} diff --git a/internal/meta/meta.go b/internal/meta/meta.go new file mode 100644 index 0000000..a1d129a --- /dev/null +++ b/internal/meta/meta.go @@ -0,0 +1,58 @@ +package meta + +import ( + "cmp" + "context" + "slices" + + "nadir/internal/module" + + "github.com/danielgtaylor/huma/v2" +) + +// ModuleInfo describes a registered module: its stable ID, display name, and +// the permissions it exposes. The frontend uses this to drive navigation and +// render the role/permission matrix. +type ModuleInfo struct { + ID string `json:"id" example:"system" doc:"Stable module identifier"` + Name string `json:"name" example:"System" doc:"Human-readable module name"` + Permissions []string `json:"permissions" doc:"Permissions this module exposes (never includes the \"*\" wildcard)"` +} + +type ModulesOutput struct { + Body struct { + Modules []ModuleInfo `json:"modules" doc:"Registered modules, sorted by ID"` + } +} + +// Register adds the read-only module-discovery endpoint. It is intentionally +// public: it exposes only the API's static shape (module IDs and permission +// vocabulary), the same information already served by /openapi.json. The module +// list is fixed at startup, so the response is computed once here. +func Register(api huma.API, mods []module.Module) { + infos := make([]ModuleInfo, 0, len(mods)) + for _, m := range mods { + perms := m.Permissions() + ps := make([]string, len(perms)) + for i, p := range perms { + ps[i] = string(p) + } + infos = append(infos, ModuleInfo{ID: m.ID(), Name: m.Name(), Permissions: ps}) + } + slices.SortFunc(infos, func(a, b ModuleInfo) int { return cmp.Compare(a.ID, b.ID) }) + + huma.Register(api, huma.Operation{ + OperationID: "list-modules", + Method: "GET", + Path: "/api/_modules", + Summary: "List registered modules", + Description: "Returns every registered module with its ID, display name, " + + "and exported permissions. Public (same static shape as /openapi.json); " + + "used by the frontend for navigation and the role/permission matrix.", + Tags: []string{"Meta"}, + }, func(ctx context.Context, _ *struct{}) (*ModulesOutput, error) { + out := &ModulesOutput{} + out.Body.Modules = infos + return out, nil + }) +} diff --git a/internal/meta/whoami.go b/internal/meta/whoami.go new file mode 100644 index 0000000..ab3c411 --- /dev/null +++ b/internal/meta/whoami.go @@ -0,0 +1,67 @@ +package meta + +import ( + "context" + + "nadir/internal/auth" + "nadir/internal/module" + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +// WhoamiInput carries the session cookie. The endpoint is not behind the RBAC +// middleware (it requires no specific permission), so it validates the session +// itself. +type WhoamiInput struct { + SessionID string `cookie:"nadir_session_id"` +} + +// WhoamiBody reports who the caller is and, per module, which permissions they +// actually hold. Combined with /api/_modules (the full module/permission grid), +// this gives the frontend everything it needs to render the permission matrix. +type WhoamiBody struct { + Username string `json:"username" example:"urania" doc:"Authenticated username"` + Permissions map[string][]string `json:"permissions" doc:"Module ID -> permissions the caller holds. Modules where they hold none are omitted."` +} + +type WhoamiOutput struct{ Body WhoamiBody } + +// RegisterWhoami adds the current-user endpoint. It resolves the caller's +// concrete grants by asking the RBAC store about each module's permissions, +// so "*" wildcards in roles are expanded for free. +func RegisterWhoami(api huma.API, sessions *auth.SessionStore, roles *rbac.RBAC, mods []module.Module) { + huma.Register(api, huma.Operation{ + OperationID: "whoami", + Method: "GET", + Path: "/api/whoami", + Summary: "Get the current user and their permissions", + Description: "Returns the authenticated username and, per module, the " + + "permissions the caller holds (wildcards resolved). Pair with " + + "/api/_modules to render the full permission matrix.", + Tags: []string{"Meta"}, + Errors: []int{401}, + }, func(ctx context.Context, in *WhoamiInput) (*WhoamiOutput, error) { + sess, ok := sessions.GetByToken(in.SessionID) + if !ok { + return nil, huma.Error401Unauthorized("unauthorized") + } + + held := make(map[string][]string) + for _, m := range mods { + var perms []string + for _, p := range m.Permissions() { + if roles.Can(sess.Username, m.ID(), p) { + perms = append(perms, string(p)) + } + } + if len(perms) > 0 { + held[m.ID()] = perms + } + } + + out := &WhoamiOutput{} + out.Body = WhoamiBody{Username: sess.Username, Permissions: held} + return out, nil + }) +} diff --git a/internal/module/module.go b/internal/module/module.go new file mode 100644 index 0000000..aae3031 --- /dev/null +++ b/internal/module/module.go @@ -0,0 +1,14 @@ +package module + +import ( + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +type Module interface { + ID() string + Name() string + Permissions() []rbac.Permission // permissions this module exposes (no "*") + Register(api huma.API) +} diff --git a/internal/modules/audit/module.go b/internal/modules/audit/module.go new file mode 100644 index 0000000..0dce079 --- /dev/null +++ b/internal/modules/audit/module.go @@ -0,0 +1,62 @@ +package audit + +import ( + "context" + + "nadir/internal/auditlog" + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +const ModuleID = "audit" + +type Module struct { + store *auditlog.Store +} + +func New(store *auditlog.Store) *Module { return &Module{store: store} } + +func (m *Module) ID() string { return ModuleID } +func (m *Module) Name() string { return "Audit" } + +// Permissions: read to view the audit trail. There is no write - entries are +// produced by the middleware, never by an API call. +func (m *Module) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Read} +} + +// Types are named AuditList* (not ListInput/ListOutput) because Huma derives +// OpenAPI schema names from the Go type name alone, not package-qualified, so a +// bare "ListOutput" here would collide with the packages module's. +type AuditListInput struct { + Limit int `query:"limit" default:"200" minimum:"1" maximum:"10000" doc:"Max entries to return, newest first"` +} + +type AuditListOutput struct { + Body struct { + Entries []auditlog.Entry `json:"entries" doc:"Recorded actions, newest first"` + } +} + +func (m *Module) Register(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "audit-list", + Method: "GET", + Path: "/api/audit", + Summary: "List recorded actions", + Description: "Returns the audit trail of privileged write operations " + + "(who, what, when, result), newest first.", + Tags: []string{"Audit"}, + Metadata: map[string]any{"module": ModuleID, "permission": "read"}, + Errors: []int{401, 403, 500}, + }, func(ctx context.Context, in *AuditListInput) (*AuditListOutput, error) { + entries, err := m.store.List(in.Limit) + if err != nil { + return nil, huma.Error500InternalServerError("read audit log failed", err) + } + out := &AuditListOutput{} + out.Body.Entries = entries + return out, nil + }) +} diff --git a/internal/modules/groups/groups.go b/internal/modules/groups/groups.go new file mode 100644 index 0000000..bfb5fe7 --- /dev/null +++ b/internal/modules/groups/groups.go @@ -0,0 +1,225 @@ +package groups + +import ( + "context" + "os" + "regexp" + "strconv" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +const tagGroups = "Groups" + +var groupPath = "/etc/group" + +// systemGIDMax mirrors the user convention: regular groups start at 1000. +const systemGIDMax = 1000 + +var ( + readErrors = []int{401, 403, 500} + writeErrors = []int{400, 401, 403, 404, 409, 500} +) + +// groupNameRe matches valid group names (same rule as usernames). Starting with +// a letter/underscore also rejects leading-dash flag injection. +var groupNameRe = regexp.MustCompile(`^[a-z_][a-z0-9_-]{0,31}\$?$`) + +// Group mirrors one /etc/group entry. +type Group struct { + Name string `json:"name" example:"wheel" doc:"Group name"` + GID int `json:"gid" example:"10" doc:"Group ID"` + Members []string `json:"members" doc:"Supplementary members (primary-group members are not listed here)"` + System bool `json:"system" doc:"True for system groups (gid < 1000)"` +} + +type ListGroupsOutput struct { + Body struct { + Groups []Group `json:"groups" doc:"All groups from /etc/group"` + } +} + +type GetGroupOutput struct{ Body Group } + +type GroupPath struct { + Group string `path:"group" example:"wheel" doc:"Group name"` +} + +type CreateGroupInput struct { + Body struct { + Name string `json:"name" example:"developers" doc:"Group name"` + GID *int `json:"gid,omitempty" example:"1500" doc:"Explicit GID (omit to auto-assign)"` + System bool `json:"system,omitempty" doc:"Create a system group (groupadd --system)"` + } +} + +func registerGroups(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "groups-list", + Method: "GET", + Path: "/api/groups", + Summary: "List groups", + Description: "Returns every group in /etc/group, including system groups " + + "(flagged via `system`).", + Tags: []string{tagGroups}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListGroupsOutput, error) { + list, err := listGroups() + if err != nil { + return nil, huma.Error500InternalServerError("read "+groupPath+" failed", err) + } + out := &ListGroupsOutput{} + out.Body.Groups = list + return out, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "groups-get", + Method: "GET", + Path: "/api/groups/{group}", + Summary: "Get a single group", + Description: "Returns one group by name. 404 if it does not exist.", + Tags: []string{tagGroups}, + Metadata: op("read"), + Errors: []int{400, 401, 403, 404, 500}, + }, func(ctx context.Context, in *GroupPath) (*GetGroupOutput, error) { + if err := validateGroupName(in.Group); err != nil { + return nil, err + } + g, ok, err := lookupGroup(in.Group) + if err != nil { + return nil, huma.Error500InternalServerError("read "+groupPath+" failed", err) + } + if !ok { + return nil, huma.Error404NotFound("group not found: " + in.Group) + } + return &GetGroupOutput{Body: g}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "groups-create", + Method: "POST", + Path: "/api/groups", + Summary: "Create a group", + Description: "Creates a group via groupadd. 409 if the group already exists.", + Tags: []string{tagGroups}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *CreateGroupInput) (*GetGroupOutput, error) { + if err := validateGroupName(in.Body.Name); err != nil { + return nil, err + } + if _, ok, err := lookupGroup(in.Body.Name); err != nil { + return nil, huma.Error500InternalServerError("read "+groupPath+" failed", err) + } else if ok { + return nil, huma.Error409Conflict("group already exists: " + in.Body.Name) + } + + args := []string{} + if in.Body.System { + args = append(args, "--system") + } + if in.Body.GID != nil { + args = append(args, "-g", strconv.Itoa(*in.Body.GID)) + } + args = append(args, "--", in.Body.Name) + if _, err := oscmd.Run("groupadd", args...); err != nil { + return nil, huma.Error500InternalServerError("groupadd failed", err) + } + + g, ok, err := lookupGroup(in.Body.Name) + if err != nil || !ok { + return nil, huma.Error500InternalServerError("group created but could not be read back", err) + } + return &GetGroupOutput{Body: g}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "groups-delete", + Method: "DELETE", + Path: "/api/groups/{group}", + Summary: "Delete a group", + Description: "Removes a group via groupdel. Returns 409 if it is the primary " + + "group of an existing user (groupdel refuses), 404 if it does not exist.", + Tags: []string{tagGroups}, + // Deleting a group is irreversible - gated behind root, not write. + Metadata: op("root"), + Errors: []int{400, 401, 403, 404, 409, 500}, + }, func(ctx context.Context, in *GroupPath) (*oscmd.StatusOutput, error) { + if err := validateGroupName(in.Group); err != nil { + return nil, err + } + if _, ok, err := lookupGroup(in.Group); err != nil { + return nil, huma.Error500InternalServerError("read "+groupPath+" failed", err) + } else if !ok { + return nil, huma.Error404NotFound("group not found: " + in.Group) + } + if _, err := oscmd.Run("groupdel", "--", in.Group); err != nil { + // groupdel refuses to remove a user's primary group; make that actionable. + return nil, huma.Error409Conflict("groupdel failed (is it a user's primary group?)", err) + } + return oscmd.OK(), nil + }) +} + +func validateGroupName(name string) error { + if !groupNameRe.MatchString(name) { + return huma.Error400BadRequest("invalid group name: " + name) + } + return nil +} + +func listGroups() ([]Group, error) { + data, err := os.ReadFile(groupPath) + if err != nil { + return nil, err + } + return parseGroup(data), nil +} + +func lookupGroup(name string) (Group, bool, error) { + list, err := listGroups() + if err != nil { + return Group{}, false, err + } + for _, g := range list { + if g.Name == name { + return g, true, nil + } + } + return Group{}, false, nil +} + +// parseGroup parses /etc/group content. Blank, commented, or malformed lines +// (fewer than 4 fields, non-numeric gid) are skipped. +func parseGroup(data []byte) []Group { + var groups []Group + for line := range strings.SplitSeq(string(data), "\n") { + if line == "" || strings.HasPrefix(line, "#") { + continue + } + f := strings.Split(line, ":") + if len(f) < 4 { + continue + } + gid, err := strconv.Atoi(f[2]) + if err != nil { + continue + } + var members []string + if f[3] != "" { + members = strings.Split(f[3], ",") + } + groups = append(groups, Group{ + Name: f[0], + GID: gid, + Members: members, + System: gid < systemGIDMax, + }) + } + return groups +} diff --git a/internal/modules/groups/groups_handler_test.go b/internal/modules/groups/groups_handler_test.go new file mode 100644 index 0000000..550a78b --- /dev/null +++ b/internal/modules/groups/groups_handler_test.go @@ -0,0 +1,109 @@ +package groups + +import ( + "encoding/json" + "net/http" + "os" + "path/filepath" + "reflect" + "testing" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestGroupsHandlers(t *testing.T) { + tempGroup := filepath.Join(t.TempDir(), "group") + initialContent := "root:x:0:\nwheel:x:10:alice,bob\n" + if err := os.WriteFile(tempGroup, []byte(initialContent), 0644); err != nil { + t.Fatal(err) + } + + oldGroup := groupPath + groupPath = tempGroup + defer func() { groupPath = oldGroup }() + + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + registerGroups(api) + + // 1. Test GET /api/groups + resp := api.Get("/api/groups") + if resp.Code != http.StatusOK { + t.Errorf("list groups: got %d, want %d", resp.Code, http.StatusOK) + } + var listRes ListGroupsOutput + if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil { + t.Fatal(err) + } + if len(listRes.Body.Groups) != 2 { + t.Errorf("got %d groups, want 2", len(listRes.Body.Groups)) + } + + // 2. Test GET /api/groups/{group} + resp = api.Get("/api/groups/wheel") + if resp.Code != http.StatusOK { + t.Errorf("get group: got %d, want %d", resp.Code, http.StatusOK) + } + var getRes GetGroupOutput + if err := json.Unmarshal(resp.Body.Bytes(), &getRes.Body); err != nil { + t.Fatal(err) + } + if getRes.Body.Name != "wheel" || getRes.Body.GID != 10 { + t.Errorf("get group: got %+v", getRes.Body) + } + + resp = api.Get("/api/groups/nonexistent") + if resp.Code != http.StatusNotFound { + t.Errorf("get non-existent group: got %d, want %d", resp.Code, http.StatusNotFound) + } + + // 3. Test POST /api/groups + oscmd.SetMock("groupadd", func(args []string) oscmd.MockCommand { + wantArgs := []string{"-g", "1500", "--", "dev"} + if !reflect.DeepEqual(args, wantArgs) { + t.Errorf("groupadd args: got %v, want %v", args, wantArgs) + } + devContent := initialContent + "dev:x:1500:\n" + os.WriteFile(tempGroup, []byte(devContent), 0644) + return oscmd.MockCommand{ExitCode: 0} + }) + defer oscmd.ClearMocks() + + gidVal := 1500 + resp = api.Post("/api/groups", struct { + Name string `json:"name"` + GID *int `json:"gid"` + }{ + Name: "dev", + GID: &gidVal, + }) + if resp.Code != http.StatusOK { + t.Errorf("create group: got %d, want %d", resp.Code, http.StatusOK) + } + + // 4. Test DELETE /api/groups/{group} + oscmd.SetMock("groupdel", func(args []string) oscmd.MockCommand { + wantArgs := []string{"--", "dev"} + if !reflect.DeepEqual(args, wantArgs) { + t.Errorf("groupdel args: got %v, want %v", args, wantArgs) + } + os.WriteFile(tempGroup, []byte(initialContent), 0644) + return oscmd.MockCommand{ExitCode: 0} + }) + + resp = api.Delete("/api/groups/dev") + if resp.Code != http.StatusOK { + t.Errorf("delete group: got %d, want %d", resp.Code, http.StatusOK) + } +} diff --git a/internal/modules/groups/groups_test.go b/internal/modules/groups/groups_test.go new file mode 100644 index 0000000..564ee1d --- /dev/null +++ b/internal/modules/groups/groups_test.go @@ -0,0 +1,51 @@ +package groups + +import ( + "reflect" + "testing" +) + +func TestParseGroup(t *testing.T) { + data := []byte(`root:x:0: +# comment + +wheel:x:10:alice,bob +developers:x:1500:alice +empty:x:1600: +broken:x:notanumber:x +short:x:5 +`) + got := parseGroup(data) + if len(got) != 4 { + t.Fatalf("expected 4 valid groups, got %d: %+v", len(got), got) + } + + wheel := got[1] + if wheel.Name != "wheel" || wheel.GID != 10 || !wheel.System || + !reflect.DeepEqual(wheel.Members, []string{"alice", "bob"}) { + t.Errorf("wheel parsed wrong: %+v", wheel) + } + + dev := got[2] + if dev.GID != 1500 || dev.System { + t.Errorf("developers should be a non-system group: %+v", dev) + } + + empty := got[3] + if len(empty.Members) != 0 { + t.Errorf("empty group should have no members, got %v", empty.Members) + } +} + +func TestValidateGroupName(t *testing.T) { + for _, n := range []string{"wheel", "_svc", "dev-team", "g1"} { + if err := validateGroupName(n); err != nil { + t.Errorf("validateGroupName(%q) = %v, want nil", n, err) + } + } + for _, n := range []string{"", "-x", "Wheel", "a,b", "foo;rm", "1grp"} { + if err := validateGroupName(n); err == nil { + t.Errorf("validateGroupName(%q) = nil, want error", n) + } + } +} diff --git a/internal/modules/groups/module.go b/internal/modules/groups/module.go new file mode 100644 index 0000000..e144dab --- /dev/null +++ b/internal/modules/groups/module.go @@ -0,0 +1,30 @@ +package groups + +import ( + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +const ModuleID = "groups" + +type Module struct{} + +func New() *Module { return &Module{} } + +func (m *Module) ID() string { return ModuleID } +func (m *Module) Name() string { return "Groups" } + +// Permissions: read to list/inspect groups; write to create; root to delete +// (irreversible). Group membership lives in the users module. +func (m *Module) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Read, rbac.Write, rbac.Root} +} + +func (m *Module) Register(api huma.API) { + registerGroups(api) +} + +func op(permission string) map[string]any { + return map[string]any{"module": ModuleID, "permission": permission} +} diff --git a/internal/modules/networking/backend.go b/internal/modules/networking/backend.go new file mode 100644 index 0000000..368af30 --- /dev/null +++ b/internal/modules/networking/backend.go @@ -0,0 +1,90 @@ +package networking + +import ( + "context" + "fmt" + "os/exec" + "time" + + "nadir/internal/oscmd" +) + +// backend is the write-side abstraction. Each host network manager (nmcli, +// networkd, ifupdown) implements this. Reads go through `ip -j` and +// /etc/resolv.conf regardless - they are backend-agnostic (see read.go). +// +// When detect() finds no backend, Module.be is nil and all write endpoints +// return 501 Not Implemented. Reads still work. +// Methods that shell out take a context so a request that is cancelled (client +// disconnect, timeout) kills the slow command (e.g. `nmcli con up` waiting on +// DHCP). The timer-driven auto-revert, which must finish even with no client, +// passes context.Background(). +type backend interface { + // Name returns the backend identifier ("nmcli", "networkd", "ifupdown"). + Name() string + + // Snapshot captures the current IPv4 configuration of iface so it can be + // restored on rollback. The returned IfaceConfig is backend-specific: + // nmcli reads from NM's connection, networkd/ifupdown read their config + // files, falling back to live `ip` output when no managed file exists + // (in which case Method is "dhcp" - safest revert assumption). + Snapshot(ctx context.Context, iface string) (IfaceConfig, error) + + // Apply replaces the interface's IPv4 configuration with cfg. It is the + // caller's responsibility to have taken a Snapshot first (the rollback + // mechanism does this). Apply must be idempotent: calling it with the + // same cfg twice should leave the system in the same state. + Apply(ctx context.Context, iface string, cfg IfaceConfig) error + + // SetLinkUp brings the interface up. + SetLinkUp(ctx context.Context, iface string) error + + // SetLinkDown takes the interface down. + SetLinkDown(ctx context.Context, iface string) error +} + +// detect probes the host for a supported network manager, in priority order: +// +// 1. nmcli (NetworkManager) - the majority of desktop and modern server installs +// 2. networkctl (systemd-networkd) - common on minimal/container hosts +// 3. ifup/ifdown (ifupdown) - classic Debian/Ubuntu servers +// +// Returns nil when none is found. The order matters: some distros ship both NM +// and networkd; NM wins because it's the active manager in that case. +func detect() backend { + if _, err := exec.LookPath("nmcli"); err == nil { + if _, err := oscmd.Run("nmcli", "general", "status"); err == nil { + return &nmcliBackend{} + } + } + if _, err := exec.LookPath("networkctl"); err == nil { + if _, err := oscmd.Run("systemctl", "is-active", "--quiet", "systemd-networkd"); err == nil { + return &networkdBackend{} + } + } + if _, err := exec.LookPath("ifup"); err == nil { + if _, err := exec.LookPath("ifdown"); err == nil { + return &ifupdownBackend{} + } + } + return nil +} + +// pendingChange tracks a single in-flight change that has been applied but not +// yet confirmed. revert undoes it (re-apply the prior config, or bring a +// downed link back up). If the timer fires before confirmation, revert runs - +// protecting against lock-yourself-out mistakes. +// +// ponytail: one slot for the whole module, not per-interface. An admin makes one +// change at a time; a concurrent change to another iface is rejected with a 409 +// that says so. Key it by iface (a map) if multi-interface concurrency is ever +// needed. +type pendingChange struct { + Iface string // interface that was changed + revert func() error // undoes the change, for rollback + Timer *time.Timer // fires the auto-revert + Deadline time.Time // when the timer will fire (for the status endpoint) +} + +// errNoBackend is the 501 returned when no write backend was detected. +var errNoBackend = fmt.Errorf("no supported network backend detected (tried nmcli, networkctl, ifup/ifdown)") diff --git a/internal/modules/networking/command.go b/internal/modules/networking/command.go new file mode 100644 index 0000000..f5f774c --- /dev/null +++ b/internal/modules/networking/command.go @@ -0,0 +1,133 @@ +package networking + +import ( + "net/netip" + "regexp" + + "github.com/danielgtaylor/huma/v2" +) + +// tagNetworking groups every networking-module operation under one OpenAPI tag, +// keeping tags 1:1 with modules. +const tagNetworking = "Networking" + +var ( + readErrors = []int{401, 403, 500} + writeErrors = []int{400, 401, 403, 409, 500, 501} +) + +// ifaceNameRe matches a Linux interface name. Linux caps names at 15 bytes and +// forbids '/' and whitespace; we additionally reject a leading dash so a name +// can never be read as a flag by `ip`/`nmcli`/`ifup`. Every command line also +// passes user-supplied names after a "--" separator. +var ifaceNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._@-]{0,14}$`) + +func validateIface(name string) error { + if !ifaceNameRe.MatchString(name) { + return huma.Error400BadRequest("invalid interface name: " + name) + } + return nil +} + +// IfaceConfig is the declarative desired state for one interface. A PUT replaces +// the interface's IPv4 configuration with this (and IPv6 too, when the IPv6 block +// is included), across whichever backend is in use. Routes are part of the config +// (declarative add/remove): the client sends the full set it wants. +type IfaceConfig struct { + Method string `json:"method" enum:"static,dhcp" example:"static" doc:"\"static\" for a fixed address, \"dhcp\" for automatic"` + Address string `json:"address,omitempty" example:"192.168.1.10" doc:"IPv4 address (static only)"` + Prefix int `json:"prefix,omitempty" minimum:"0" maximum:"32" example:"24" doc:"Network prefix length (static only)"` + Gateway string `json:"gateway,omitempty" example:"192.168.1.1" doc:"Default gateway (static only, optional)"` + IPv6 *IPv6Config `json:"ipv6,omitempty" doc:"Optional IPv6 settings. Omit to leave IPv6 untouched; include to manage it."` + DNS []string `json:"dns,omitempty" example:"[\"1.1.1.1\",\"8.8.8.8\"]" doc:"DNS servers for this interface (IPv4 or IPv6)"` + Routes []Route `json:"routes,omitempty" doc:"Static routes to install for this interface"` + RollbackSeconds int `json:"rollback_seconds,omitempty" minimum:"0" maximum:"3600" example:"60" doc:"Auto-revert after this many seconds unless confirmed. 0 uses the default (60s)."` +} + +// IPv6Config is the optional IPv6 settings for an interface. Method "auto" uses +// SLAAC/router advertisements (the usual default), "static" pins an address, and +// "ignore" disables IPv6 on the interface. DHCPv6 is not modeled. +type IPv6Config struct { + Method string `json:"method" enum:"auto,static,ignore" example:"static" doc:"\"auto\" (SLAAC), \"static\", or \"ignore\" (disable IPv6)"` + Address string `json:"address,omitempty" example:"2001:db8::10" doc:"IPv6 address (static only)"` + Prefix int `json:"prefix,omitempty" minimum:"0" maximum:"128" example:"64" doc:"Prefix length (static only)"` + Gateway string `json:"gateway,omitempty" example:"2001:db8::1" doc:"IPv6 default gateway (static only, optional)"` +} + +// Route is a single static route. Destination is a CIDR (or \"default\"). +type Route struct { + Destination string `json:"destination" example:"10.0.0.0/24" doc:"Destination network in CIDR notation, or \"default\""` + Gateway string `json:"gateway" example:"192.168.1.1" doc:"Next-hop gateway"` +} + +// validate checks the desired config independently of the backend, so a bad +// request is a 400 before we touch the system. IP/CIDR parsing uses net/netip. +func (c IfaceConfig) validate() error { + switch c.Method { + case "static": + if c.Address == "" { + return huma.Error400BadRequest("static method requires an address") + } + if _, err := netip.ParseAddr(c.Address); err != nil { + return huma.Error400BadRequest("invalid address: " + c.Address) + } + if c.Prefix < 1 || c.Prefix > 32 { + return huma.Error400BadRequest("prefix must be 1-32") + } + if c.Gateway != "" { + if _, err := netip.ParseAddr(c.Gateway); err != nil { + return huma.Error400BadRequest("invalid gateway: " + c.Gateway) + } + } + case "dhcp": + // address/gateway/prefix are ignored; nothing to validate. + default: + return huma.Error400BadRequest("method must be \"static\" or \"dhcp\"") + } + for _, s := range c.DNS { + if _, err := netip.ParseAddr(s); err != nil { + return huma.Error400BadRequest("invalid DNS server: " + s) + } + } + for _, r := range c.Routes { + if r.Destination != "default" { + if _, err := netip.ParsePrefix(r.Destination); err != nil { + return huma.Error400BadRequest("invalid route destination: " + r.Destination) + } + } + if _, err := netip.ParseAddr(r.Gateway); err != nil { + return huma.Error400BadRequest("invalid route gateway: " + r.Gateway) + } + } + if c.IPv6 != nil { + if err := c.IPv6.validate(); err != nil { + return err + } + } + return nil +} + +// validate checks an IPv6 block. Static addresses/gateways must parse as IPv6 +// (an IPv4 literal here is a client mistake). +func (c IPv6Config) validate() error { + switch c.Method { + case "auto", "ignore": + // no address fields to validate + case "static": + addr, err := netip.ParseAddr(c.Address) + if err != nil || addr.Is4() { + return huma.Error400BadRequest("ipv6 static requires a valid IPv6 address, got: " + c.Address) + } + if c.Prefix < 1 || c.Prefix > 128 { + return huma.Error400BadRequest("ipv6 prefix must be 1-128") + } + if c.Gateway != "" { + if gw, err := netip.ParseAddr(c.Gateway); err != nil || gw.Is4() { + return huma.Error400BadRequest("invalid ipv6 gateway: " + c.Gateway) + } + } + default: + return huma.Error400BadRequest("ipv6 method must be \"auto\", \"static\", or \"ignore\"") + } + return nil +} diff --git a/internal/modules/networking/command_test.go b/internal/modules/networking/command_test.go new file mode 100644 index 0000000..094e7f0 --- /dev/null +++ b/internal/modules/networking/command_test.go @@ -0,0 +1,65 @@ +package networking + +import "testing" + +func TestValidateIfaceConfig(t *testing.T) { + tests := []struct { + name string + cfg IfaceConfig + wantErr bool + }{ + {"valid static", IfaceConfig{ + Method: "static", Address: "192.168.1.10", Prefix: 24, Gateway: "192.168.1.1", + DNS: []string{"1.1.1.1", "8.8.8.8"}, + }, false}, + {"valid static no gateway", IfaceConfig{ + Method: "static", Address: "10.0.0.5", Prefix: 8, + }, false}, + {"valid dhcp", IfaceConfig{Method: "dhcp"}, false}, + {"valid with routes", IfaceConfig{ + Method: "static", Address: "192.168.1.10", Prefix: 24, + Routes: []Route{ + {Destination: "100.64.0.0/24", Gateway: "192.168.1.1"}, + {Destination: "default", Gateway: "192.168.1.1"}, + }, + }, false}, + {"static missing address", IfaceConfig{Method: "static", Prefix: 24}, true}, + {"static bad address", IfaceConfig{Method: "static", Address: "not-an-ip", Prefix: 24}, true}, + {"static bad gateway", IfaceConfig{Method: "static", Address: "10.0.0.1", Prefix: 24, Gateway: "bad"}, true}, + {"bad method", IfaceConfig{Method: "pppoe"}, true}, + {"bad dns", IfaceConfig{Method: "dhcp", DNS: []string{"not-an-ip"}}, true}, + {"bad route destination", IfaceConfig{ + Method: "static", Address: "10.0.0.1", Prefix: 24, + Routes: []Route{{Destination: "nope", Gateway: "10.0.0.1"}}, + }, true}, + {"bad route gateway", IfaceConfig{ + Method: "static", Address: "10.0.0.1", Prefix: 24, + Routes: []Route{{Destination: "10.0.0.0/24", Gateway: "nope"}}, + }, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.validate() + if (err != nil) != tt.wantErr { + t.Errorf("validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateIface(t *testing.T) { + valid := []string{"eth0", "enp3s0", "wlan0", "br-lan", "veth1234567", "docker0"} + for _, name := range valid { + if err := validateIface(name); err != nil { + t.Errorf("validateIface(%q) unexpected error: %v", name, err) + } + } + + invalid := []string{"", "-eth0", "/dev/net", "a b", "name_that_is_way_too_long_for_linux"} + for _, name := range invalid { + if err := validateIface(name); err == nil { + t.Errorf("validateIface(%q) expected error", name) + } + } +} diff --git a/internal/modules/networking/hosts.go b/internal/modules/networking/hosts.go new file mode 100644 index 0000000..15c5819 --- /dev/null +++ b/internal/modules/networking/hosts.go @@ -0,0 +1,208 @@ +package networking + +import ( + "context" + "net/netip" + "os" + "regexp" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +// /etc/hosts management ("Host Addresses"). Edits are surgical — upsert/delete a +// single IP's line — so existing comments, ordering and the localhost entries +// are preserved rather than rewritten away. + +var hostsFile = "/etc/hosts" + +// hostnameRe matches a single hostname/alias. It forbids whitespace and '#' (so +// an entry can't inject extra fields or a comment) and a leading dash. +var hostnameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`) + +// HostEntry is one /etc/hosts mapping: an IP and the names that resolve to it. +type HostEntry struct { + IP string `json:"ip" example:"192.168.1.10" doc:"IPv4 or IPv6 address"` + Hostnames []string `json:"hostnames" example:"[\"server\",\"server.local\"]" doc:"Names mapped to the address"` +} + +type ListHostsOutput struct { + Body struct { + Entries []HostEntry `json:"entries"` + } +} + +type HostUpsertInput struct { + IP string `path:"ip" example:"192.168.1.10" doc:"IP address to add or update"` + Body struct { + Hostnames []string `json:"hostnames" example:"[\"server\",\"server.local\"]" doc:"Names to map to the IP"` + } +} + +type HostDeleteInput struct { + IP string `path:"ip" example:"192.168.1.10" doc:"IP address whose entry to remove"` +} + +func registerHosts(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "networking-list-hosts", + Method: "GET", + Path: "/api/networking/hosts", + Summary: "List /etc/hosts entries", + Description: "Returns the static host-to-address mappings from /etc/hosts.", + Tags: []string{tagNetworking}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListHostsOutput, error) { + data, err := os.ReadFile(hostsFile) + if err != nil { + return nil, huma.Error500InternalServerError("read hosts failed", err) + } + res := &ListHostsOutput{} + res.Body.Entries = parseHosts(string(data)) + return res, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-upsert-host", + Method: "PUT", + Path: "/api/networking/hosts/{ip}", + Summary: "Add or update a /etc/hosts entry", + Description: "Sets the hostnames for an IP. Replaces the existing line for that IP, " + + "or appends a new one; all other lines (comments included) are left untouched.", + Tags: []string{tagNetworking}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *HostUpsertInput) (*oscmd.StatusOutput, error) { + if err := validateHost(in.IP, in.Body.Hostnames); err != nil { + return nil, err + } + if err := upsertHost(in.IP, in.Body.Hostnames); err != nil { + return nil, huma.Error500InternalServerError("write hosts failed", err) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-delete-host", + Method: "DELETE", + Path: "/api/networking/hosts/{ip}", + Summary: "Remove a /etc/hosts entry", + Description: "Removes the line(s) mapping the given IP. 404 if no entry exists for it.", + Tags: []string{tagNetworking}, + Metadata: op("write"), + Errors: []int{400, 401, 403, 404, 500}, + }, func(ctx context.Context, in *HostDeleteInput) (*oscmd.StatusOutput, error) { + if _, err := netip.ParseAddr(in.IP); err != nil { + return nil, huma.Error400BadRequest("invalid IP: " + in.IP) + } + removed, err := deleteHost(in.IP) + if err != nil { + return nil, huma.Error500InternalServerError("write hosts failed", err) + } + if !removed { + return nil, huma.Error404NotFound("no hosts entry for " + in.IP) + } + return oscmd.OK(), nil + }) +} + +func validateHost(ip string, hostnames []string) error { + if _, err := netip.ParseAddr(ip); err != nil { + return huma.Error400BadRequest("invalid IP: " + ip) + } + if len(hostnames) == 0 { + return huma.Error400BadRequest("at least one hostname is required") + } + for _, h := range hostnames { + if !hostnameRe.MatchString(h) { + return huma.Error400BadRequest("invalid hostname: " + h) + } + } + return nil +} + +// --- parse / render (pure, tested) ------------------------------------------- + +// hostFields returns the IP and hostnames on a hosts line, or ok=false for a +// blank or comment-only line. An inline "# comment" tail is stripped. +func hostFields(line string) (ip string, names []string, ok bool) { + if i := strings.IndexByte(line, '#'); i >= 0 { + line = line[:i] + } + f := strings.Fields(line) + if len(f) < 2 { + return "", nil, false + } + return f[0], f[1:], true +} + +func parseHosts(text string) []HostEntry { + entries := []HostEntry{} + for line := range strings.SplitSeq(text, "\n") { + if ip, names, ok := hostFields(line); ok { + entries = append(entries, HostEntry{IP: ip, Hostnames: names}) + } + } + return entries +} + +// renderHostLine formats one entry as a hosts line. +func renderHostLine(ip string, hostnames []string) string { + return ip + "\t" + strings.Join(hostnames, " ") +} + +// --- writes ------------------------------------------------------------------ + +// upsertHost replaces the line for ip (matched on the address field) or appends +// a new one, preserving every other line. +func upsertHost(ip string, hostnames []string) error { + data, err := os.ReadFile(hostsFile) + if err != nil { + return err + } + lines := strings.Split(string(data), "\n") + newLine := renderHostLine(ip, hostnames) + replaced := false + for i, line := range lines { + if got, _, ok := hostFields(line); ok && got == ip { + lines[i] = newLine + replaced = true + break + } + } + if !replaced { + // Append, avoiding a blank line if the file already ended with one. + if n := len(lines); n > 0 && strings.TrimSpace(lines[n-1]) == "" { + lines[n-1] = newLine + } else { + lines = append(lines, newLine) + } + lines = append(lines, "") + } + return os.WriteFile(hostsFile, []byte(strings.Join(lines, "\n")), 0644) +} + +// deleteHost removes every line mapping ip and reports whether any were removed. +func deleteHost(ip string) (bool, error) { + data, err := os.ReadFile(hostsFile) + if err != nil { + return false, err + } + lines := strings.Split(string(data), "\n") + kept := make([]string, 0, len(lines)) + removed := false + for _, line := range lines { + if got, _, ok := hostFields(line); ok && got == ip { + removed = true + continue + } + kept = append(kept, line) + } + if !removed { + return false, nil + } + return true, os.WriteFile(hostsFile, []byte(strings.Join(kept, "\n")), 0644) +} diff --git a/internal/modules/networking/hosts_test.go b/internal/modules/networking/hosts_test.go new file mode 100644 index 0000000..8ed520a --- /dev/null +++ b/internal/modules/networking/hosts_test.go @@ -0,0 +1,92 @@ +package networking + +import ( + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +const sampleHosts = `# static table +127.0.0.1 localhost +::1 localhost ip6-localhost +192.168.1.10 server server.local # the box + +` + +func TestParseHosts(t *testing.T) { + got := parseHosts(sampleHosts) + want := []HostEntry{ + {IP: "127.0.0.1", Hostnames: []string{"localhost"}}, + {IP: "::1", Hostnames: []string{"localhost", "ip6-localhost"}}, + {IP: "192.168.1.10", Hostnames: []string{"server", "server.local"}}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("parseHosts:\n got %+v\nwant %+v", got, want) + } +} + +func TestUpsertAndDeleteHost(t *testing.T) { + path := filepath.Join(t.TempDir(), "hosts") + if err := os.WriteFile(path, []byte(sampleHosts), 0644); err != nil { + t.Fatal(err) + } + old := hostsFile + hostsFile = path + defer func() { hostsFile = old }() + + // Update an existing IP — comments and other lines must survive. + if err := upsertHost("192.168.1.10", []string{"web", "web.local"}); err != nil { + t.Fatal(err) + } + data, _ := os.ReadFile(path) + if !strings.Contains(string(data), "# static table") || !strings.Contains(string(data), "ip6-localhost") { + t.Errorf("upsert clobbered other lines:\n%s", data) + } + entries := parseHosts(string(data)) + if e := findHost(entries, "192.168.1.10"); e == nil || !reflect.DeepEqual(e.Hostnames, []string{"web", "web.local"}) { + t.Errorf("upsert did not update entry: %+v", entries) + } + + // Add a new IP. + if err := upsertHost("10.0.0.5", []string{"db"}); err != nil { + t.Fatal(err) + } + entries = parseHosts(mustRead(t, path)) + if findHost(entries, "10.0.0.5") == nil { + t.Errorf("upsert did not append new entry: %+v", entries) + } + + // Delete it again. + removed, err := deleteHost("10.0.0.5") + if err != nil || !removed { + t.Fatalf("delete failed: removed=%v err=%v", removed, err) + } + if findHost(parseHosts(mustRead(t, path)), "10.0.0.5") != nil { + t.Error("entry still present after delete") + } + + // Deleting a missing IP reports not-removed. + if removed, _ := deleteHost("8.8.8.8"); removed { + t.Error("expected removed=false for missing IP") + } +} + +func findHost(entries []HostEntry, ip string) *HostEntry { + for i := range entries { + if entries[i].IP == ip { + return &entries[i] + } + } + return nil +} + +func mustRead(t *testing.T, path string) string { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + return string(data) +} diff --git a/internal/modules/networking/ifupdown.go b/internal/modules/networking/ifupdown.go new file mode 100644 index 0000000..12ed821 --- /dev/null +++ b/internal/modules/networking/ifupdown.go @@ -0,0 +1,239 @@ +package networking + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "nadir/internal/oscmd" +) + +// ifupdownBackend implements backend via classic Debian ifupdown. Configuration +// is done by writing stanza files under interfacesDDir. Each managed interface +// gets "nadir-" so we never touch /etc/network/interfaces itself. +// +// Prerequisite: /etc/network/interfaces must contain +// +// source /etc/network/interfaces.d/* +// +// for these stanzas to take effect. If missing, Apply checks for this and adds +// the source directive (same auto-provision pattern as the PAM service). +type ifupdownBackend struct{} + +var ( + interfacesFile = "/etc/network/interfaces" + interfacesDDir = "/etc/network/interfaces.d" +) + +func (b *ifupdownBackend) Name() string { return "ifupdown" } + +// ifupdownFile returns the path for the nadir-managed stanza file for iface. +// iface is already validated by validateIface to prevent path traversal. +func ifupdownFile(iface string) string { + return filepath.Join(interfacesDDir, "nadir-"+iface) +} + +func (b *ifupdownBackend) Snapshot(ctx context.Context, iface string) (IfaceConfig, error) { + path := ifupdownFile(iface) + data, err := os.ReadFile(path) + if err != nil { + // No nadir-managed stanza → fall back to live ip output, assume DHCP. + return snapshotFromIP(ctx, iface) + } + return parseIfupdownStanza(string(data)), nil +} + +// parseIfupdownStanza extracts IfaceConfig from an ifupdown stanza file. A file +// may hold both an "inet" (IPv4) and an "inet6" (IPv6) stanza for the interface; +// family tracks which one the indented keys below belong to. +func parseIfupdownStanza(content string) IfaceConfig { + cfg := IfaceConfig{Method: "dhcp"} + family := "inet" + + for line := range strings.SplitSeq(content, "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + switch fields[0] { + case "iface": + // "iface eth0 inet static" / "iface eth0 inet6 auto" + if len(fields) >= 4 { + family = fields[2] + switch family { + case "inet": + if fields[3] == "static" { + cfg.Method = "static" + } + case "inet6": + switch fields[3] { + case "static": + v6(&cfg).Method = "static" + default: // auto / dhcp / manual → treat as autoconf + v6(&cfg).Method = "auto" + } + } + } + case "address": + addr, prefix := splitCIDR(fields[1]) + if family == "inet6" { + g := v6(&cfg) + g.Address = addr + if prefix > 0 { + g.Prefix = prefix + } + } else if addr != "" { + cfg.Address = addr + if prefix > 0 { + cfg.Prefix = prefix + } + } + case "gateway": + if family == "inet6" { + v6(&cfg).Gateway = fields[1] + } else { + cfg.Gateway = fields[1] + } + case "dns-nameservers": + cfg.DNS = append(cfg.DNS, fields[1:]...) + case "up": + // "up ip route add 10.0.0.0/24 via 192.168.1.1" + r := parseUpRoute(fields[1:]) + if r.Destination != "" { + cfg.Routes = append(cfg.Routes, r) + } + } + } + return cfg +} + +// parseUpRoute extracts a Route from "ip route add via " post-up commands. +func parseUpRoute(args []string) Route { + // Expected: ["ip", "route", "add", "10.0.0.0/24", "via", "192.168.1.1"] + if len(args) < 6 || args[0] != "ip" || args[1] != "route" || args[2] != "add" { + return Route{} + } + r := Route{Destination: args[3]} + for i, a := range args { + if a == "via" && i+1 < len(args) { + r.Gateway = args[i+1] + break + } + } + return r +} + +func (b *ifupdownBackend) Apply(ctx context.Context, iface string, cfg IfaceConfig) error { + // Ensure interfaces.d exists and is sourced. + if err := ensureSourceDirective(); err != nil { + return err + } + + content := renderIfupdownStanza(iface, cfg) + path := ifupdownFile(iface) + + if err := os.MkdirAll(interfacesDDir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", interfacesDDir, err) + } + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + return fmt.Errorf("write %s: %w", path, err) + } + + // Bring the interface down and back up with the new config. + // ifdown may fail if the interface wasn't previously managed — ignore that. + _, _ = oscmd.RunContext(ctx, "ifdown", "--force", "--", iface) + if _, err := oscmd.RunContext(ctx, "ifup", "--", iface); err != nil { + return fmt.Errorf("ifup %s: %w", iface, err) + } + return nil +} + +// ensureSourceDirective checks that /etc/network/interfaces contains a source +// line for interfaces.d. If not, appends one (same auto-provision pattern as +// the PAM service file). +func ensureSourceDirective() error { + data, err := os.ReadFile(interfacesFile) + if err != nil { + // If the file doesn't exist at all, that's a broken system; don't create it. + return fmt.Errorf("read %s: %w", interfacesFile, err) + } + + content := string(data) + // Check for any source line pointing at interfaces.d. + for line := range strings.SplitSeq(content, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "source") && strings.Contains(line, "interfaces.d") { + return nil // already present + } + } + + // Append the source directive. + addition := "\n# Added by nadir to pick up per-interface stanzas.\nsource " + interfacesDDir + "/*\n" + if err := os.WriteFile(interfacesFile, []byte(content+addition), 0644); err != nil { + return fmt.Errorf("append source directive to %s: %w", interfacesFile, err) + } + return nil +} + +// renderIfupdownStanza builds the content for an ifupdown stanza file. +func renderIfupdownStanza(iface string, cfg IfaceConfig) string { + var b strings.Builder + b.WriteString("# Managed by nadir — do not edit manually.\n") + + switch cfg.Method { + case "dhcp": + fmt.Fprintf(&b, "auto %s\n", iface) + fmt.Fprintf(&b, "iface %s inet dhcp\n", iface) + case "static": + fmt.Fprintf(&b, "auto %s\n", iface) + fmt.Fprintf(&b, "iface %s inet static\n", iface) + fmt.Fprintf(&b, " address %s/%d\n", cfg.Address, cfg.Prefix) + if cfg.Gateway != "" { + fmt.Fprintf(&b, " gateway %s\n", cfg.Gateway) + } + } + + if len(cfg.DNS) > 0 { + fmt.Fprintf(&b, " dns-nameservers %s\n", strings.Join(cfg.DNS, " ")) + } + + for _, r := range cfg.Routes { + fmt.Fprintf(&b, " up ip route add %s via %s\n", r.Destination, r.Gateway) + fmt.Fprintf(&b, " down ip route del %s via %s\n", r.Destination, r.Gateway) + } + + // IPv6 goes in its own inet6 stanza (the "auto eth0" above covers both). + if cfg.IPv6 != nil { + switch cfg.IPv6.Method { + case "auto": + fmt.Fprintf(&b, "iface %s inet6 auto\n", iface) + case "static": + fmt.Fprintf(&b, "iface %s inet6 static\n", iface) + fmt.Fprintf(&b, " address %s/%d\n", cfg.IPv6.Address, cfg.IPv6.Prefix) + if cfg.IPv6.Gateway != "" { + fmt.Fprintf(&b, " gateway %s\n", cfg.IPv6.Gateway) + } + case "ignore": + // no inet6 stanza + } + } + + return b.String() +} + +func (b *ifupdownBackend) SetLinkUp(ctx context.Context, iface string) error { + _, err := oscmd.RunContext(ctx, "ifup", "--", iface) + return err +} + +func (b *ifupdownBackend) SetLinkDown(ctx context.Context, iface string) error { + _, err := oscmd.RunContext(ctx, "ifdown", "--", iface) + return err +} diff --git a/internal/modules/networking/module.go b/internal/modules/networking/module.go new file mode 100644 index 0000000..90819ff --- /dev/null +++ b/internal/modules/networking/module.go @@ -0,0 +1,43 @@ +package networking + +import ( + "sync" + + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +const ModuleID = "networking" + +type Module struct { + // be is the detected network backend (nmcli / networkd / ifupdown). nil when + // none was found: reads still work (they go through `ip`), writes return 501. + be backend + // pending holds the single in-flight change awaiting confirmation, for the + // timed auto-rollback. See rollback.go. + pending *pendingChange + mu sync.Mutex +} + +// New detects the host's network backend once at startup. +func New() *Module { return &Module{be: detect()} } + +func (m *Module) ID() string { return ModuleID } +func (m *Module) Name() string { return "Networking" } + +// Permissions: read to inspect interfaces/routes/DNS; write to reconfigure them +// (apply config, bring links up/down, confirm a pending change). +func (m *Module) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Read, rbac.Write} +} + +func (m *Module) Register(api huma.API) { + registerReads(api) + registerWrites(api, m) + registerHosts(api) +} + +func op(permission string) map[string]any { + return map[string]any{"module": ModuleID, "permission": permission} +} diff --git a/internal/modules/networking/networkd.go b/internal/modules/networking/networkd.go new file mode 100644 index 0000000..57b4bb2 --- /dev/null +++ b/internal/modules/networking/networkd.go @@ -0,0 +1,285 @@ +package networking + +import ( + "context" + "fmt" + "net/netip" + "os" + "path/filepath" + "strings" + + "nadir/internal/oscmd" +) + +// networkdBackend implements backend via systemd-networkd. Configuration is done +// by writing .network files under networkdDir. Each managed interface gets its +// own file named "90-nadir-.network" — the 90 prefix puts nadir's config +// after most distro defaults, and the "nadir-" infix ensures we never clobber +// distro-provided files. +type networkdBackend struct{} + +var networkdDir = "/etc/systemd/network" + +func (b *networkdBackend) Name() string { return "networkd" } + +// networkdFile returns the path for the nadir-managed .network file for iface. +// iface is already validated by validateIface to prevent path traversal. +func networkdFile(iface string) string { + return filepath.Join(networkdDir, "90-nadir-"+iface+".network") +} + +func (b *networkdBackend) Snapshot(ctx context.Context, iface string) (IfaceConfig, error) { + path := networkdFile(iface) + data, err := os.ReadFile(path) + if err != nil { + // No nadir-managed file exists. Fall back to live ip output and assume + // DHCP — the safest rollback assumption (reverts to "whatever the + // system was doing before nadir touched it"). + return snapshotFromIP(ctx, iface) + } + return parseNetworkdFile(string(data)), nil +} + +// snapshotFromIP captures the current state from ip -j, assuming DHCP as the +// method since there's no way to tell from live output. +func snapshotFromIP(ctx context.Context, iface string) (IfaceConfig, error) { + cfg := IfaceConfig{Method: "dhcp"} + + // Grab addresses. + out, err := oscmd.RunContext(ctx, "ip", "-j", "addr", "show", "--", iface) + if err != nil { + return cfg, nil // interface may not exist yet; DHCP fallback is fine + } + ifaces, err := parseInterfaces(out) + if err != nil || len(ifaces) == 0 { + return cfg, nil + } + + // If there are IPv4 addresses, capture the first one. + if len(ifaces[0].IPv4) > 0 { + addr, prefix := splitCIDR(ifaces[0].IPv4[0]) + cfg.Method = "static" + cfg.Address = addr + cfg.Prefix = prefix + } + + // Capture a global IPv6 address if present, skipping link-local (fe80::/10), + // so a rollback restores the interface's real IPv6 state. + cfg.IPv6 = &IPv6Config{Method: "auto"} + for _, c := range ifaces[0].IPv6 { + addr, prefix := splitCIDR(c) + if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLinkLocalUnicast() { + cfg.IPv6 = &IPv6Config{Method: "static", Address: addr, Prefix: prefix} + break + } + } + + // Grab the default gateway for this interface. + routeOut, err := oscmd.RunContext(ctx, "ip", "-j", "route", "show", "dev", "--", iface) + if err == nil { + routes, _ := parseRoutes(routeOut) + for _, r := range routes { + if r.Destination == "default" && r.Gateway != "" { + cfg.Gateway = r.Gateway + break + } + } + } + + // Grab DNS from /etc/resolv.conf + if data, err := os.ReadFile(resolvConf); err == nil { + cfg.DNS = parseResolv(string(data)) + } + + return cfg, nil +} + +// parseNetworkdFile extracts IfaceConfig from a systemd .network file. +func parseNetworkdFile(content string) IfaceConfig { + cfg := IfaceConfig{Method: "dhcp"} + section := "" + + for line := range strings.SplitSeq(content, "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { + section = line + continue + } + + key, val, ok := strings.Cut(line, "=") + if !ok { + continue + } + key = strings.TrimSpace(key) + val = strings.TrimSpace(val) + + if section == "[Network]" { + switch key { + case "DHCP": + if val == "yes" || val == "ipv4" { + cfg.Method = "dhcp" + } + case "Address": + addr, prefix := splitCIDR(val) + if ip, err := netip.ParseAddr(addr); err == nil && !ip.Is4() { + g := v6(&cfg) + g.Method = "static" + g.Address = addr + g.Prefix = prefix + } else if addr != "" { + cfg.Method = "static" + cfg.Address = addr + cfg.Prefix = prefix + } + case "Gateway": + if ip, err := netip.ParseAddr(val); err == nil && !ip.Is4() { + v6(&cfg).Gateway = val + } else { + cfg.Gateway = val + } + case "DNS": + for s := range strings.FieldsSeq(val) { + cfg.DNS = append(cfg.DNS, s) + } + case "IPv6AcceptRA": + if val == "yes" { + if g := v6(&cfg); g.Method != "static" { + g.Method = "auto" + } + } + case "LinkLocalAddressing": + if val == "no" { + v6(&cfg).Method = "ignore" + } + } + } + } + + // Parse [Route] sections separately. + cfg.Routes = parseNetworkdRoutes(content) + return cfg +} + +// parseNetworkdRoutes extracts Route entries from [Route] sections. +func parseNetworkdRoutes(content string) []Route { + var routes []Route + var inRoute bool + var current Route + + for line := range strings.SplitSeq(content, "\n") { + line = strings.TrimSpace(line) + + if strings.HasPrefix(line, "[") { + // Flush any pending route when entering a new section. + if inRoute && (current.Destination != "" || current.Gateway != "") { + routes = append(routes, current) + } + inRoute = line == "[Route]" + current = Route{} + continue + } + + if !inRoute { + continue + } + + key, val, ok := strings.Cut(line, "=") + if !ok { + continue + } + switch strings.TrimSpace(key) { + case "Destination": + current.Destination = strings.TrimSpace(val) + case "Gateway": + current.Gateway = strings.TrimSpace(val) + } + } + // Flush last route if we ended inside a [Route] section. + if inRoute && (current.Destination != "" || current.Gateway != "") { + routes = append(routes, current) + } + return routes +} + +func (b *networkdBackend) Apply(ctx context.Context, iface string, cfg IfaceConfig) error { + content := renderNetworkdFile(iface, cfg) + path := networkdFile(iface) + + if err := os.MkdirAll(networkdDir, 0755); err != nil { + return fmt.Errorf("mkdir %s: %w", networkdDir, err) + } + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + return fmt.Errorf("write %s: %w", path, err) + } + + // Reload networkd and reconfigure the specific interface. + if _, err := oscmd.RunContext(ctx, "networkctl", "reload"); err != nil { + return fmt.Errorf("networkctl reload: %w", err) + } + if _, err := oscmd.RunContext(ctx, "networkctl", "reconfigure", "--", iface); err != nil { + return fmt.Errorf("networkctl reconfigure %s: %w", iface, err) + } + return nil +} + +// renderNetworkdFile builds the INI content for a systemd .network file. +func renderNetworkdFile(iface string, cfg IfaceConfig) string { + var b strings.Builder + + b.WriteString("# Managed by nadir — do not edit manually.\n") + b.WriteString("[Match]\n") + fmt.Fprintf(&b, "Name=%s\n\n", iface) + + b.WriteString("[Network]\n") + switch cfg.Method { + case "dhcp": + b.WriteString("DHCP=yes\n") + case "static": + b.WriteString("DHCP=no\n") + fmt.Fprintf(&b, "Address=%s/%d\n", cfg.Address, cfg.Prefix) + if cfg.Gateway != "" { + fmt.Fprintf(&b, "Gateway=%s\n", cfg.Gateway) + } + } + if cfg.IPv6 != nil { + switch cfg.IPv6.Method { + case "static": + fmt.Fprintf(&b, "Address=%s/%d\n", cfg.IPv6.Address, cfg.IPv6.Prefix) + if cfg.IPv6.Gateway != "" { + fmt.Fprintf(&b, "Gateway=%s\n", cfg.IPv6.Gateway) + } + b.WriteString("IPv6AcceptRA=no\n") + case "auto": + b.WriteString("IPv6AcceptRA=yes\n") + case "ignore": + b.WriteString("LinkLocalAddressing=no\nIPv6AcceptRA=no\n") + } + } + + for _, dns := range cfg.DNS { + fmt.Fprintf(&b, "DNS=%s\n", dns) + } + + for _, r := range cfg.Routes { + b.WriteString("\n[Route]\n") + fmt.Fprintf(&b, "Destination=%s\n", r.Destination) + fmt.Fprintf(&b, "Gateway=%s\n", r.Gateway) + } + + return b.String() +} + +func (b *networkdBackend) SetLinkUp(ctx context.Context, iface string) error { + // Note: ip link set parses DEVICE positionally, so -- is technically ignored + // by ip but included here for consistency with other oscmd calls. + _, err := oscmd.RunContext(ctx, "ip", "link", "set", "--", iface, "up") + return err +} + +func (b *networkdBackend) SetLinkDown(ctx context.Context, iface string) error { + _, err := oscmd.RunContext(ctx, "ip", "link", "set", "--", iface, "down") + return err +} diff --git a/internal/modules/networking/networking_handler_test.go b/internal/modules/networking/networking_handler_test.go new file mode 100644 index 0000000..61b69da --- /dev/null +++ b/internal/modules/networking/networking_handler_test.go @@ -0,0 +1,397 @@ +package networking + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + "time" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +type mockBackend struct { + name string + snapshotResult IfaceConfig + applyCalledWith IfaceConfig + applyErr error + snapshotErr error + setUpCalled bool + setDownCalled bool +} + +func (m *mockBackend) Name() string { return m.name } +func (m *mockBackend) Snapshot(_ context.Context, iface string) (IfaceConfig, error) { + return m.snapshotResult, m.snapshotErr +} +func (m *mockBackend) Apply(_ context.Context, iface string, cfg IfaceConfig) error { + m.applyCalledWith = cfg + return m.applyErr +} +func (m *mockBackend) SetLinkUp(_ context.Context, iface string) error { + m.setUpCalled = true + return nil +} +func (m *mockBackend) SetLinkDown(_ context.Context, iface string) error { + m.setDownCalled = true + return nil +} + +func TestNetworkingHandlers(t *testing.T) { + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + + be := &mockBackend{ + name: "mockbe", + snapshotResult: IfaceConfig{ + Method: "dhcp", + Address: "192.168.1.10/24", + }, + } + m := &Module{be: be} + m.Register(api) + + tempResolv := filepath.Join(t.TempDir(), "resolv.conf") + if err := os.WriteFile(tempResolv, []byte("nameserver 1.1.1.1\nnameserver 8.8.8.8\n"), 0644); err != nil { + t.Fatal(err) + } + oldResolv := resolvConf + resolvConf = tempResolv + defer func() { resolvConf = oldResolv }() + + oscmd.SetMock("ip", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"-j", "addr"}) { + out := `[{"ifname": "eth0", "operstate": "UP", "address": "aa:bb:cc:dd:ee:ff", "mtu": 1500, "addr_info": [{"family": "inet", "local": "192.168.1.10", "prefixlen": 24}]}]` + return oscmd.MockCommand{Stdout: out, ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"-j", "route"}) { + out := `[{"dst": "default", "gateway": "192.168.1.1", "dev": "eth0"}]` + return oscmd.MockCommand{Stdout: out, ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + defer oscmd.ClearMocks() + + // 1. Test GET /api/networking/interfaces + resp := api.Get("/api/networking/interfaces") + if resp.Code != http.StatusOK { + t.Errorf("list interfaces: got %d, want %d", resp.Code, http.StatusOK) + } + + // 2. Test GET /api/networking/routes + resp = api.Get("/api/networking/routes") + if resp.Code != http.StatusOK { + t.Errorf("list routes: got %d, want %d", resp.Code, http.StatusOK) + } + + // 3. Test GET /api/networking/dns + resp = api.Get("/api/networking/dns") + if resp.Code != http.StatusOK { + t.Errorf("get dns: got %d, want %d", resp.Code, http.StatusOK) + } + var dnsRes DNSOutput + if err := json.Unmarshal(resp.Body.Bytes(), &dnsRes.Body); err != nil { + t.Fatal(err) + } + if len(dnsRes.Body.Servers) != 2 || dnsRes.Body.Servers[0] != "1.1.1.1" { + t.Errorf("get dns output: %+v", dnsRes.Body) + } + + // 4. Test PUT /api/networking/interfaces/{name} + applyPayload := struct { + Method string `json:"method"` + Address string `json:"address,omitempty"` + Prefix int `json:"prefix,omitempty"` + Gateway string `json:"gateway,omitempty"` + DNS []string `json:"dns,omitempty"` + RollbackSeconds int `json:"rollback_seconds,omitempty"` + }{ + Method: "static", + Address: "192.168.1.20", + Prefix: 24, + Gateway: "192.168.1.1", + DNS: []string{"1.1.1.1"}, + RollbackSeconds: 2, + } + + resp = api.Put("/api/networking/interfaces/eth0", applyPayload) + if resp.Code != http.StatusOK { + t.Errorf("apply interface config: got %d, want %d, body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + + resp = api.Get("/api/networking/pending") + if resp.Code != http.StatusOK { + t.Errorf("get pending: got %d, want %d", resp.Code, http.StatusOK) + } + + // 5. Test POST /api/networking/interfaces/{name}/confirm + resp = api.Post("/api/networking/interfaces/eth0/confirm", struct{}{}) + if resp.Code != http.StatusOK { + t.Errorf("confirm change: got %d, want %d", resp.Code, http.StatusOK) + } + + resp = api.Get("/api/networking/pending") + if resp.Code != http.StatusNotFound { + t.Errorf("pending change should be cleared: got %d, want %d", resp.Code, http.StatusNotFound) + } + + // 6. Test automatic rollback + applyPayload.RollbackSeconds = 1 + resp = api.Put("/api/networking/interfaces/eth0", applyPayload) + if resp.Code != http.StatusOK { + t.Errorf("apply config again: got %d, want %d", resp.Code, http.StatusOK) + } + + time.Sleep(1200 * time.Millisecond) + + resp = api.Get("/api/networking/pending") + if resp.Code != http.StatusNotFound { + t.Errorf("pending change should be rolled back: got %d, want %d", resp.Code, http.StatusNotFound) + } + + if be.applyCalledWith.Method != "dhcp" || be.applyCalledWith.Address != "192.168.1.10/24" { + t.Errorf("rollback failed to restore prior config: %+v", be.applyCalledWith) + } + + // 7. Test POST link up & down + resp = api.Post("/api/networking/interfaces/eth0/up", struct{}{}) + if resp.Code != http.StatusOK { + t.Errorf("set link up: got %d, want %d", resp.Code, http.StatusOK) + } + if !be.setUpCalled { + t.Errorf("expected SetLinkUp call on backend") + } + + resp = api.Post("/api/networking/interfaces/eth0/down", struct{}{}) + if resp.Code != http.StatusOK { + t.Errorf("set link down: got %d, want %d", resp.Code, http.StatusOK) + } + if !be.setDownCalled { + t.Errorf("expected SetLinkDown call on backend") + } +} + +// #3: a failed Apply must restore the prior snapshot (no half-applied config +// left with no auto-revert) and arm no pending change. +func TestApplyFailureRestoresPrior(t *testing.T) { + be := &mockBackend{ + name: "mock", + snapshotResult: IfaceConfig{Method: "dhcp"}, + applyErr: errors.New("con up failed"), + } + m := &Module{be: be} + + _, err := m.startRollback(t.Context(), "eth0", IfaceConfig{Method: "static", Address: "10.0.0.5", Prefix: 24}) + if err == nil { + t.Fatal("expected apply to fail") + } + // The last Apply call should be the restore-to-prior (dhcp), not the new config. + if be.applyCalledWith.Method != "dhcp" { + t.Errorf("expected restore to prior config, last Apply got %+v", be.applyCalledWith) + } + if m.pending != nil { + t.Error("no pending change should be armed after a failed apply") + } +} + +// #1: link-down must go through the rollback safety net (arms a pending change), +// and rolling it back brings the interface back up. #2: a change to another +// interface while one is pending is rejected with errAlreadyPending. +func TestLinkDownArmsRollback(t *testing.T) { + be := &mockBackend{name: "mock"} + m := &Module{be: be} + + secs, err := m.startLinkDown(t.Context(), "eth0") + if err != nil { + t.Fatal(err) + } + if secs != defaultRollbackSeconds { + t.Errorf("seconds = %d, want default %d", secs, defaultRollbackSeconds) + } + if !be.setDownCalled { + t.Error("expected SetLinkDown to be called") + } + if m.pending == nil { + t.Fatal("expected a pending change to be armed") + } + + // A concurrent change to a different interface is rejected (global lock). + if _, err := m.startRollback(t.Context(), "eth1", IfaceConfig{Method: "dhcp"}); !errors.Is(err, errAlreadyPending) { + t.Errorf("expected errAlreadyPending for eth1, got %v", err) + } + + // Rolling back brings the link back up and clears the pending change. + if err := m.rollbackNow("eth0"); err != nil { + t.Fatal(err) + } + if !be.setUpCalled { + t.Error("expected SetLinkUp on rollback") + } + if m.pending != nil { + t.Error("pending change should be cleared after rollback") + } +} + +func TestBackendImplementations(t *testing.T) { + tempDir := t.TempDir() + + // Mock ip command for fallback snapshots and link control + oscmd.SetMock("ip", func(args []string) oscmd.MockCommand { + argStr := strings.Join(args, " ") + if strings.Contains(argStr, "addr show") { + out := `[{"ifname": "eth0", "operstate": "UP", "address": "aa:bb:cc:dd:ee:ff", "mtu": 1500, "addr_info": [{"family": "inet", "local": "192.168.1.30", "prefixlen": 24}]}]` + return oscmd.MockCommand{Stdout: out, ExitCode: 0} + } + if strings.Contains(argStr, "route show") { + out := `[{"dst": "default", "gateway": "192.168.1.1", "dev": "eth0"}]` + return oscmd.MockCommand{Stdout: out, ExitCode: 0} + } + if strings.Contains(argStr, "link set") { + return oscmd.MockCommand{ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + defer oscmd.ClearMocks() + + // 1. Test nmcliBackend + nm := &nmcliBackend{} + oscmd.SetMock("nmcli", func(args []string) oscmd.MockCommand { + argStr := strings.Join(args, " ") + t.Logf("nmcli mock called with: %s", argStr) + if strings.Contains(argStr, "con show --active") { + return oscmd.MockCommand{Stdout: "myconn:eth0\n", ExitCode: 0} + } + if strings.Contains(argStr, "con show") && strings.Contains(argStr, "myconn") { + showOut := "ipv4.method:manual\nipv4.addresses:192.168.1.10/24\nipv4.gateway:192.168.1.1\nipv4.dns:1.1.1.1\nipv4.routes:10.0.0.0/24 192.168.1.1\n" + return oscmd.MockCommand{Stdout: showOut, ExitCode: 0} + } + if (strings.Contains(argStr, "con modify") || strings.Contains(argStr, "con up") || strings.Contains(argStr, "con down")) && strings.Contains(argStr, "myconn") { + return oscmd.MockCommand{ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + + cfg, err := nm.Snapshot(t.Context(), "eth0") + if err != nil { + t.Fatal(err) + } + if cfg.Address != "192.168.1.10" || cfg.Gateway != "192.168.1.1" { + t.Errorf("nmcli snapshot failed: %+v", cfg) + } + + err = nm.Apply(t.Context(), "eth0", IfaceConfig{ + Method: "static", + Address: "192.168.1.20", + Prefix: 24, + Gateway: "192.168.1.1", + }) + if err != nil { + t.Fatal(err) + } + + if err := nm.SetLinkUp(t.Context(), "eth0"); err != nil { + t.Fatal(err) + } + if err := nm.SetLinkDown(t.Context(), "eth0"); err != nil { + t.Fatal(err) + } + + // 2. Test networkdBackend + nd := &networkdBackend{} + oldNdDir := networkdDir + networkdDir = tempDir + defer func() { networkdDir = oldNdDir }() + + oscmd.SetMock("networkctl", func(args []string) oscmd.MockCommand { + return oscmd.MockCommand{ExitCode: 0} + }) + + err = nd.Apply(t.Context(), "eth0", IfaceConfig{ + Method: "static", + Address: "192.168.1.30", + Prefix: 24, + Gateway: "192.168.1.1", + DNS: []string{"8.8.8.8"}, + }) + if err != nil { + t.Fatal(err) + } + + cfg, err = nd.Snapshot(t.Context(), "eth0") + if err != nil { + t.Fatal(err) + } + if cfg.Address != "192.168.1.30" || cfg.Method != "static" { + t.Errorf("networkd snapshot failed: %+v", cfg) + } + + if err := nd.SetLinkUp(t.Context(), "eth0"); err != nil { + t.Fatal(err) + } + if err := nd.SetLinkDown(t.Context(), "eth0"); err != nil { + t.Fatal(err) + } + + // 3. Test ifupdownBackend + iu := &ifupdownBackend{} + oldIfaceFile := interfacesFile + oldInterfacesDDir := interfacesDDir + interfacesFile = filepath.Join(tempDir, "interfaces") + interfacesDDir = filepath.Join(tempDir, "interfaces.d") + defer func() { + interfacesFile = oldIfaceFile + interfacesDDir = oldInterfacesDDir + }() + + if err := os.MkdirAll(interfacesDDir, 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(interfacesFile, []byte("auto lo\niface lo inet loopback\n"), 0644); err != nil { + t.Fatal(err) + } + + oscmd.SetMock("ifup", func(args []string) oscmd.MockCommand { return oscmd.MockCommand{ExitCode: 0} }) + oscmd.SetMock("ifdown", func(args []string) oscmd.MockCommand { return oscmd.MockCommand{ExitCode: 0} }) + + err = iu.Apply(t.Context(), "eth0", IfaceConfig{ + Method: "static", + Address: "192.168.1.40", + Prefix: 24, + Gateway: "192.168.1.1", + DNS: []string{"1.1.1.1"}, + }) + if err != nil { + t.Fatal(err) + } + + cfg, err = iu.Snapshot(t.Context(), "eth0") + if err != nil { + t.Fatal(err) + } + if cfg.Address != "192.168.1.40" || cfg.Method != "static" { + t.Errorf("ifupdown snapshot failed: %+v", cfg) + } + + if err := iu.SetLinkUp(t.Context(), "eth0"); err != nil { + t.Fatal(err) + } + if err := iu.SetLinkDown(t.Context(), "eth0"); err != nil { + t.Fatal(err) + } +} diff --git a/internal/modules/networking/networking_test.go b/internal/modules/networking/networking_test.go new file mode 100644 index 0000000..1050094 --- /dev/null +++ b/internal/modules/networking/networking_test.go @@ -0,0 +1,396 @@ +package networking + +import ( + "reflect" + "strings" + "testing" +) + +func TestParseInterfaces(t *testing.T) { + // Trimmed real output from `ip -j addr` on a typical host. + input := `[ + { + "ifname": "lo", + "address": "00:00:00:00:00:00", + "mtu": 65536, + "operstate": "UNKNOWN", + "addr_info": [ + {"family": "inet", "local": "127.0.0.1", "prefixlen": 8}, + {"family": "inet6", "local": "::1", "prefixlen": 128} + ] + }, + { + "ifname": "eth0", + "address": "52:54:00:12:34:56", + "mtu": 1500, + "operstate": "UP", + "addr_info": [ + {"family": "inet", "local": "192.168.1.10", "prefixlen": 24}, + {"family": "inet6", "local": "fe80::1", "prefixlen": 64} + ] + } +]` + + ifaces, err := parseInterfaces(input) + if err != nil { + t.Fatal(err) + } + if len(ifaces) != 2 { + t.Fatalf("expected 2 interfaces, got %d", len(ifaces)) + } + + lo := ifaces[0] + if lo.Name != "lo" || lo.State != "unknown" || lo.MTU != 65536 { + t.Errorf("lo: got %+v", lo) + } + if len(lo.IPv4) != 1 || lo.IPv4[0] != "127.0.0.1/8" { + t.Errorf("lo ipv4: got %v", lo.IPv4) + } + + eth0 := ifaces[1] + if eth0.Name != "eth0" || eth0.State != "up" || eth0.MAC != "52:54:00:12:34:56" { + t.Errorf("eth0: got %+v", eth0) + } + if len(eth0.IPv4) != 1 || eth0.IPv4[0] != "192.168.1.10/24" { + t.Errorf("eth0 ipv4: got %v", eth0.IPv4) + } + if len(eth0.IPv6) != 1 || eth0.IPv6[0] != "fe80::1/64" { + t.Errorf("eth0 ipv6: got %v", eth0.IPv6) + } +} + +func TestParseRoutes(t *testing.T) { + input := `[ + {"dst": "default", "gateway": "192.168.1.1", "dev": "eth0", "prefsrc": "192.168.1.10", "metric": 100}, + {"dst": "192.168.1.0/24", "dev": "eth0", "prefsrc": "192.168.1.10"} +]` + + routes, err := parseRoutes(input) + if err != nil { + t.Fatal(err) + } + if len(routes) != 2 { + t.Fatalf("expected 2 routes, got %d", len(routes)) + } + if routes[0].Destination != "default" || routes[0].Gateway != "192.168.1.1" || routes[0].Metric != 100 { + t.Errorf("route 0: got %+v", routes[0]) + } + if routes[1].Destination != "192.168.1.0/24" || routes[1].Interface != "eth0" { + t.Errorf("route 1: got %+v", routes[1]) + } +} + +func TestParseResolv(t *testing.T) { + input := `# Generated by NetworkManager +nameserver 1.1.1.1 +nameserver 8.8.8.8 +; legacy comment +search example.com +nameserver 9.9.9.9 +` + servers := parseResolv(input) + want := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"} + if !reflect.DeepEqual(servers, want) { + t.Errorf("parseResolv() = %v, want %v", servers, want) + } +} + +func TestParseResolvEmpty(t *testing.T) { + servers := parseResolv("") + if len(servers) != 0 { + t.Errorf("expected empty, got %v", servers) + } +} + +func TestParseNmcliSnapshot(t *testing.T) { + input := `ipv4.method:manual +ipv4.addresses:192.168.1.10/24 +ipv4.gateway:192.168.1.1 +ipv4.dns:1.1.1.1,8.8.8.8 +ipv4.routes:--` + + cfg := parseNmcliSnapshot(input) + if cfg.Method != "static" { + t.Errorf("method: got %q, want static", cfg.Method) + } + if cfg.Address != "192.168.1.10" || cfg.Prefix != 24 { + t.Errorf("address: got %s/%d", cfg.Address, cfg.Prefix) + } + if cfg.Gateway != "192.168.1.1" { + t.Errorf("gateway: got %q", cfg.Gateway) + } + if len(cfg.DNS) != 2 || cfg.DNS[0] != "1.1.1.1" || cfg.DNS[1] != "8.8.8.8" { + t.Errorf("dns: got %v", cfg.DNS) + } +} + +func TestParseNmcliRoutes(t *testing.T) { + input := `dst=10.0.0.0/24, nh=192.168.1.1; dst=172.16.0.0/12, nh=10.0.0.1` + routes := parseNmcliRoutes(input) + if len(routes) != 2 { + t.Fatalf("expected 2 routes, got %d", len(routes)) + } + if routes[0].Destination != "10.0.0.0/24" || routes[0].Gateway != "192.168.1.1" { + t.Errorf("route 0: got %v", routes[0]) + } + if routes[1].Destination != "172.16.0.0/12" || routes[1].Gateway != "10.0.0.1" { + t.Errorf("route 1: got %v", routes[1]) + } +} + +func TestParseNmcliSnapshotDHCP(t *testing.T) { + input := `ipv4.method:auto +ipv4.addresses:-- +ipv4.gateway:-- +ipv4.dns:-- +ipv4.routes:--` + + cfg := parseNmcliSnapshot(input) + if cfg.Method != "dhcp" { + t.Errorf("method: got %q, want dhcp", cfg.Method) + } +} + +func TestParseNetworkdFile(t *testing.T) { + input := `# Managed by nadir +[Match] +Name=eth0 + +[Network] +DHCP=no +Address=10.0.0.5/24 +Gateway=10.0.0.1 +DNS=1.1.1.1 +DNS=8.8.8.8 + +[Route] +Destination=192.168.0.0/16 +Gateway=10.0.0.254 +` + + cfg := parseNetworkdFile(input) + if cfg.Method != "static" { + t.Errorf("method: got %q, want static", cfg.Method) + } + if cfg.Address != "10.0.0.5" || cfg.Prefix != 24 { + t.Errorf("address: got %s/%d", cfg.Address, cfg.Prefix) + } + if cfg.Gateway != "10.0.0.1" { + t.Errorf("gateway: got %q", cfg.Gateway) + } + if len(cfg.DNS) != 2 { + t.Errorf("dns: got %v", cfg.DNS) + } + if len(cfg.Routes) != 1 || cfg.Routes[0].Destination != "192.168.0.0/16" || cfg.Routes[0].Gateway != "10.0.0.254" { + t.Errorf("routes: got %v", cfg.Routes) + } +} + +func TestParseNetworkdFileDHCP(t *testing.T) { + input := `[Match] +Name=eth0 + +[Network] +DHCP=yes +` + cfg := parseNetworkdFile(input) + if cfg.Method != "dhcp" { + t.Errorf("method: got %q, want dhcp", cfg.Method) + } +} + +func TestParseIfupdownStanza(t *testing.T) { + input := `# Managed by nadir +auto eth0 +iface eth0 inet static + address 192.168.1.10/24 + gateway 192.168.1.1 + dns-nameservers 1.1.1.1 8.8.8.8 + up ip route add 10.0.0.0/24 via 192.168.1.254 + down ip route del 10.0.0.0/24 via 192.168.1.254 +` + + cfg := parseIfupdownStanza(input) + if cfg.Method != "static" { + t.Errorf("method: got %q, want static", cfg.Method) + } + if cfg.Address != "192.168.1.10" || cfg.Prefix != 24 { + t.Errorf("address: got %s/%d", cfg.Address, cfg.Prefix) + } + if cfg.Gateway != "192.168.1.1" { + t.Errorf("gateway: got %q", cfg.Gateway) + } + if len(cfg.DNS) != 2 || cfg.DNS[0] != "1.1.1.1" || cfg.DNS[1] != "8.8.8.8" { + t.Errorf("dns: got %v", cfg.DNS) + } + if len(cfg.Routes) != 1 || cfg.Routes[0].Destination != "10.0.0.0/24" || cfg.Routes[0].Gateway != "192.168.1.254" { + t.Errorf("routes: got %v", cfg.Routes) + } +} + +func TestParseIfupdownStanzaDHCP(t *testing.T) { + input := `auto eth0 +iface eth0 inet dhcp +` + cfg := parseIfupdownStanza(input) + if cfg.Method != "dhcp" { + t.Errorf("method: got %q, want dhcp", cfg.Method) + } +} + +func TestRenderNetworkdFile(t *testing.T) { + cfg := IfaceConfig{ + Method: "static", + Address: "10.0.0.5", + Prefix: 24, + Gateway: "10.0.0.1", + DNS: []string{"1.1.1.1"}, + Routes: []Route{{Destination: "192.168.0.0/16", Gateway: "10.0.0.254"}}, + } + out := renderNetworkdFile("eth0", cfg) + + mustContain := []string{ + "Name=eth0", + "DHCP=no", + "Address=10.0.0.5/24", + "Gateway=10.0.0.1", + "DNS=1.1.1.1", + "Destination=192.168.0.0/16", + } + for _, s := range mustContain { + if !strings.Contains(out, s) { + t.Errorf("renderNetworkdFile missing %q in:\n%s", s, out) + } + } + + // Roundtrip test + parsed := parseNetworkdFile(out) + if !reflect.DeepEqual(parsed, cfg) { + t.Errorf("roundtrip failed: got %+v, want %+v", parsed, cfg) + } +} + +func TestRenderIfupdownStanza(t *testing.T) { + cfg := IfaceConfig{ + Method: "static", + Address: "192.168.1.10", + Prefix: 24, + Gateway: "192.168.1.1", + DNS: []string{"1.1.1.1", "8.8.8.8"}, + Routes: []Route{{Destination: "10.0.0.0/24", Gateway: "192.168.1.254"}}, + } + out := renderIfupdownStanza("eth0", cfg) + + mustContain := []string{ + "iface eth0 inet static", + "address 192.168.1.10/24", + "gateway 192.168.1.1", + "dns-nameservers 1.1.1.1 8.8.8.8", + "up ip route add 10.0.0.0/24 via 192.168.1.254", + "down ip route del 10.0.0.0/24 via 192.168.1.254", + } + for _, s := range mustContain { + if !strings.Contains(out, s) { + t.Errorf("renderIfupdownStanza missing %q in:\n%s", s, out) + } + } + + // Roundtrip test + parsed := parseIfupdownStanza(out) + if !reflect.DeepEqual(parsed, cfg) { + t.Errorf("roundtrip failed: got %+v, want %+v", parsed, cfg) + } +} + +func TestValidateIPv6(t *testing.T) { + ok := IfaceConfig{Method: "dhcp", IPv6: &IPv6Config{Method: "static", Address: "2001:db8::1", Prefix: 64, Gateway: "2001:db8::ff"}} + if err := ok.validate(); err != nil { + t.Errorf("valid ipv6 rejected: %v", err) + } + bad := []IfaceConfig{ + {Method: "dhcp", IPv6: &IPv6Config{Method: "static", Address: "1.2.3.4", Prefix: 64}}, // v4 in v6 block + {Method: "dhcp", IPv6: &IPv6Config{Method: "static", Address: "2001:db8::1", Prefix: 0}}, + {Method: "dhcp", IPv6: &IPv6Config{Method: "bogus"}}, + } + for i, c := range bad { + if err := c.validate(); err == nil { + t.Errorf("bad ipv6 case %d accepted", i) + } + } +} + +func TestParseNmcliSnapshotIPv6(t *testing.T) { + input := `ipv4.method:manual +ipv4.addresses:192.168.1.10/24 +ipv6.method:manual +ipv6.addresses:2001:db8::10/64 +ipv6.gateway:2001:db8::1` + cfg := parseNmcliSnapshot(input) + if cfg.IPv6 == nil { + t.Fatal("ipv6 not captured") + } + if cfg.IPv6.Method != "static" || cfg.IPv6.Address != "2001:db8::10" || cfg.IPv6.Prefix != 64 || cfg.IPv6.Gateway != "2001:db8::1" { + t.Errorf("ipv6 snapshot: %+v", cfg.IPv6) + } +} + +func TestNetworkdIPv6RoundTrip(t *testing.T) { + cfg := IfaceConfig{ + Method: "static", + Address: "10.0.0.5", + Prefix: 24, + IPv6: &IPv6Config{Method: "static", Address: "2001:db8::5", Prefix: 64, Gateway: "2001:db8::1"}, + } + got := parseNetworkdFile(renderNetworkdFile("eth0", cfg)) + if got.Address != "10.0.0.5" || got.Prefix != 24 { + t.Errorf("ipv4 lost in roundtrip: %+v", got) + } + if !reflect.DeepEqual(got.IPv6, cfg.IPv6) { + t.Errorf("ipv6 roundtrip: got %+v want %+v", got.IPv6, cfg.IPv6) + } +} + +func TestIfupdownIPv6RoundTrip(t *testing.T) { + cfg := IfaceConfig{ + Method: "static", + Address: "192.168.1.10", + Prefix: 24, + IPv6: &IPv6Config{Method: "static", Address: "2001:db8::10", Prefix: 64, Gateway: "2001:db8::1"}, + } + got := parseIfupdownStanza(renderIfupdownStanza("eth0", cfg)) + if got.Address != "192.168.1.10" || got.Prefix != 24 || got.Method != "static" { + t.Errorf("ipv4 lost in roundtrip: %+v", got) + } + if !reflect.DeepEqual(got.IPv6, cfg.IPv6) { + t.Errorf("ipv6 roundtrip: got %+v want %+v", got.IPv6, cfg.IPv6) + } +} + +func TestNetworkdIPv6AutoIgnore(t *testing.T) { + auto := parseNetworkdFile(renderNetworkdFile("eth0", IfaceConfig{Method: "dhcp", IPv6: &IPv6Config{Method: "auto"}})) + if auto.IPv6 == nil || auto.IPv6.Method != "auto" { + t.Errorf("auto roundtrip: %+v", auto.IPv6) + } + ign := parseNetworkdFile(renderNetworkdFile("eth0", IfaceConfig{Method: "dhcp", IPv6: &IPv6Config{Method: "ignore"}})) + if ign.IPv6 == nil || ign.IPv6.Method != "ignore" { + t.Errorf("ignore roundtrip: %+v", ign.IPv6) + } +} + +func TestSplitCIDR(t *testing.T) { + tests := []struct { + input string + wantAddr string + wantPrefix int + }{ + {"192.168.1.10/24", "192.168.1.10", 24}, + {"10.0.0.1/8", "10.0.0.1", 8}, + {"10.0.0.1", "10.0.0.1", 0}, + } + for _, tt := range tests { + addr, prefix := splitCIDR(tt.input) + if addr != tt.wantAddr || prefix != tt.wantPrefix { + t.Errorf("splitCIDR(%q) = (%q, %d), want (%q, %d)", tt.input, addr, prefix, tt.wantAddr, tt.wantPrefix) + } + } +} diff --git a/internal/modules/networking/nmcli.go b/internal/modules/networking/nmcli.go new file mode 100644 index 0000000..0b5c903 --- /dev/null +++ b/internal/modules/networking/nmcli.go @@ -0,0 +1,274 @@ +package networking + +import ( + "context" + "fmt" + "strconv" + "strings" + + "nadir/internal/oscmd" +) + +// nmcliBackend implements backend via NetworkManager's nmcli CLI. +type nmcliBackend struct{} + +func (b *nmcliBackend) Name() string { return "nmcli" } + +// connForIface resolves a network interface name to the NM connection name +// that owns it. Returns an error if the interface has no active connection. +// +// nmcli -t uses ':' as the field separator. Connection names can contain colons +// (e.g. "VLAN:100"), but Linux device names cannot, so we split on the last +// colon to get the device and treat everything before it as the connection name. +func connForIface(ctx context.Context, iface string) (string, error) { + out, err := oscmd.RunContext(ctx, "nmcli", "-t", "-f", "NAME,DEVICE", "con", "show", "--active") + if err != nil { + return "", fmt.Errorf("nmcli con show: %w", err) + } + for line := range strings.SplitSeq(out, "\n") { + // Split on the last colon: "conn:name:eth0" → ("conn:name", "eth0") + idx := strings.LastIndex(line, ":") + if idx < 0 { + continue + } + name, dev := line[:idx], line[idx+1:] + if dev == iface { + return name, nil + } + } + return "", fmt.Errorf("no active NM connection found for interface %s", iface) +} + +func (b *nmcliBackend) Snapshot(ctx context.Context, iface string) (IfaceConfig, error) { + conn, err := connForIface(ctx, iface) + if err != nil { + // No managed connection → assume DHCP (safest rollback assumption). + return IfaceConfig{Method: "dhcp"}, nil + } + + out, err := oscmd.RunContext(ctx, "nmcli", "-t", "-f", + "ipv4.method,ipv4.addresses,ipv4.gateway,ipv4.dns,ipv4.routes,ipv6.method,ipv6.addresses,ipv6.gateway", + "con", "show", "--", conn) + if err != nil { + return IfaceConfig{}, fmt.Errorf("nmcli con show %s: %w", conn, err) + } + + return parseNmcliSnapshot(out), nil +} + +// parseNmcliSnapshot parses the terse output of `nmcli -t -f ... con show`. +// Fields are colon-separated key:value lines. Multi-valued fields (addresses, +// dns, routes) use comma separation within the value. +func parseNmcliSnapshot(out string) IfaceConfig { + cfg := IfaceConfig{Method: "dhcp"} + for line := range strings.SplitSeq(out, "\n") { + key, val, ok := strings.Cut(line, ":") + if !ok || val == "" || val == "--" { + continue + } + switch key { + case "ipv4.method": + if val == "manual" { + cfg.Method = "static" + } else { + cfg.Method = "dhcp" + } + case "ipv4.addresses": + // "192.168.1.10/24" or "192.168.1.10/24, 10.0.0.1/8" + for _, part := range splitTrim(val, ",") { + addr, prefix := splitCIDR(part) + if addr != "" { + cfg.Address = addr + cfg.Prefix = prefix + } + } + case "ipv4.gateway": + cfg.Gateway = val + case "ipv4.dns": + cfg.DNS = append(cfg.DNS, splitTrim(val, ",")...) + case "ipv4.routes": + // "dst=10.0.0.0/24, nh=192.168.1.1; dst=default, nh=192.168.1.1" + // or simpler "{ dst = 10.0.0.0/24, nh = 192.168.1.1 }" + cfg.Routes = parseNmcliRoutes(val) + case "ipv6.method": + v6(&cfg).Method = nmcliV6Method(val) + case "ipv6.addresses": + for _, part := range splitTrim(val, ",") { + if addr, prefix := splitCIDR(part); addr != "" { + v6(&cfg).Address = addr + v6(&cfg).Prefix = prefix + } + } + case "ipv6.gateway": + v6(&cfg).Gateway = val + } + } + return cfg +} + +// v6 lazily allocates the IPv6 block so a snapshot captures the live IPv6 state +// (defaulting to "auto") even when only some ipv6.* fields are present. +func v6(cfg *IfaceConfig) *IPv6Config { + if cfg.IPv6 == nil { + cfg.IPv6 = &IPv6Config{Method: "auto"} + } + return cfg.IPv6 +} + +// nmcliV6Method maps nmcli's ipv6.method to our vocabulary. +func nmcliV6Method(val string) string { + switch val { + case "manual": + return "static" + case "ignore", "disabled", "link-local": + return "ignore" + default: + return "auto" + } +} + +// parseNmcliRoutes parses route entries from nmcli terse output. +func parseNmcliRoutes(val string) []Route { + var routes []Route + // Routes are semicolon-separated, each containing "dst=..., nh=..." + for _, entry := range splitTrim(val, ";") { + entry = strings.Trim(entry, "{}") + var r Route + for _, part := range splitTrim(entry, ",") { + k, v, ok := strings.Cut(part, "=") + if !ok { + continue + } + switch strings.TrimSpace(k) { + case "dst": + r.Destination = strings.TrimSpace(v) + case "nh": + r.Gateway = strings.TrimSpace(v) + } + } + if r.Destination != "" && r.Gateway != "" { + routes = append(routes, r) + } + } + return routes +} + +func (b *nmcliBackend) Apply(ctx context.Context, iface string, cfg IfaceConfig) error { + conn, err := connForIface(ctx, iface) + if err != nil { + return fmt.Errorf("cannot apply: %w", err) + } + + // Build the nmcli con modify arguments. Note: conn is safe to place after + // -- since it comes from nmcli output, not directly from the user. + args := []string{"con", "modify", "--", conn} + + switch cfg.Method { + case "static": + cidr := fmt.Sprintf("%s/%d", cfg.Address, cfg.Prefix) + args = append(args, + "ipv4.method", "manual", + "ipv4.addresses", cidr, + ) + if cfg.Gateway != "" { + args = append(args, "ipv4.gateway", cfg.Gateway) + } else { + args = append(args, "ipv4.gateway", "") + } + case "dhcp": + args = append(args, + "ipv4.method", "auto", + "ipv4.addresses", "", + "ipv4.gateway", "", + ) + } + + // DNS: set or clear. + if len(cfg.DNS) > 0 { + args = append(args, "ipv4.dns", strings.Join(cfg.DNS, ",")) + } else { + args = append(args, "ipv4.dns", "") + } + + // Routes: set or clear. + if len(cfg.Routes) > 0 { + var routeStrs []string + for _, r := range cfg.Routes { + routeStrs = append(routeStrs, r.Destination+" "+r.Gateway) + } + args = append(args, "ipv4.routes", strings.Join(routeStrs, ",")) + } else { + args = append(args, "ipv4.routes", "") + } + + // IPv6: only touched when the request includes an ipv6 block. + if cfg.IPv6 != nil { + switch cfg.IPv6.Method { + case "static": + args = append(args, + "ipv6.method", "manual", + "ipv6.addresses", fmt.Sprintf("%s/%d", cfg.IPv6.Address, cfg.IPv6.Prefix), + "ipv6.gateway", cfg.IPv6.Gateway, // "" clears it + ) + case "auto": + args = append(args, "ipv6.method", "auto", "ipv6.addresses", "", "ipv6.gateway", "") + case "ignore": + args = append(args, "ipv6.method", "ignore", "ipv6.addresses", "", "ipv6.gateway", "") + } + } + + if _, err := oscmd.RunContext(ctx, "nmcli", args...); err != nil { + return fmt.Errorf("nmcli con modify: %w", err) + } + + // Bring the connection up to apply changes. + if _, err := oscmd.RunContext(ctx, "nmcli", "con", "up", "--", conn); err != nil { + return fmt.Errorf("nmcli con up: %w", err) + } + return nil +} + +func (b *nmcliBackend) SetLinkUp(ctx context.Context, iface string) error { + conn, err := connForIface(ctx, iface) + if err != nil { + // No NM connection - fall back to `ip link`. + _, err = oscmd.RunContext(ctx, "ip", "link", "set", "--", iface, "up") + return err + } + _, err = oscmd.RunContext(ctx, "nmcli", "con", "up", "--", conn) + return err +} + +func (b *nmcliBackend) SetLinkDown(ctx context.Context, iface string) error { + conn, err := connForIface(ctx, iface) + if err != nil { + _, err = oscmd.RunContext(ctx, "ip", "link", "set", "--", iface, "down") + return err + } + _, err = oscmd.RunContext(ctx, "nmcli", "con", "down", "--", conn) + return err +} + +// --- helpers ----------------------------------------------------------------- + +// splitTrim splits s on sep and returns the trimmed, non-empty segments. +func splitTrim(s, sep string) []string { + var out []string + for part := range strings.SplitSeq(s, sep) { + if p := strings.TrimSpace(part); p != "" { + out = append(out, p) + } + } + return out +} + +// splitCIDR splits "192.168.1.10/24" into ("192.168.1.10", 24). A missing or +// malformed prefix yields 0. +func splitCIDR(cidr string) (string, int) { + addr, prefixStr, ok := strings.Cut(cidr, "/") + if !ok { + return addr, 0 + } + prefix, _ := strconv.Atoi(prefixStr) + return addr, prefix +} diff --git a/internal/modules/networking/read.go b/internal/modules/networking/read.go new file mode 100644 index 0000000..b54a404 --- /dev/null +++ b/internal/modules/networking/read.go @@ -0,0 +1,211 @@ +package networking + +import ( + "context" + "encoding/json" + "os" + "strconv" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +// Reads go through `ip -j` (JSON) and /etc/resolv.conf, which behave the same +// regardless of which backend manages the interfaces - so reads need no backend +// detection. + +var resolvConf = "/etc/resolv.conf" + +type Interface struct { + Name string `json:"name" example:"eth0"` + State string `json:"state" example:"up" doc:"operstate: up / down / unknown"` + MAC string `json:"mac" example:"52:54:00:12:34:56"` + MTU int `json:"mtu" example:"1500"` + IPv4 []string `json:"ipv4" example:"[\"192.168.1.10/24\"]" doc:"IPv4 addresses in CIDR form"` + IPv6 []string `json:"ipv6" example:"[\"fe80::1/64\"]" doc:"IPv6 addresses in CIDR form"` +} + +type ListInterfacesOutput struct { + Body struct { + Interfaces []Interface `json:"interfaces"` + } +} + +type RouteEntry struct { + Destination string `json:"destination" example:"default" doc:"Destination network, or \"default\""` + Gateway string `json:"gateway,omitempty" example:"192.168.1.1"` + Interface string `json:"interface" example:"eth0"` + Source string `json:"source,omitempty" example:"192.168.1.10" doc:"Preferred source address"` + Metric int `json:"metric,omitempty" example:"100"` +} + +type ListRoutesOutput struct { + Body struct { + Routes []RouteEntry `json:"routes"` + } +} + +type DNSOutput struct { + Body struct { + Servers []string `json:"servers" example:"[\"1.1.1.1\"]" doc:"Nameservers from /etc/resolv.conf"` + } +} + +func registerReads(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "networking-list-interfaces", + Method: "GET", + Path: "/api/networking/interfaces", + Summary: "List network interfaces", + Description: "Returns every interface with its state, MAC, MTU and addresses, via `ip -j addr`.", + Tags: []string{tagNetworking}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListInterfacesOutput, error) { + out, err := oscmd.RunContext(ctx, "ip", "-j", "addr") + if err != nil { + return nil, huma.Error500InternalServerError("ip addr failed", err) + } + ifaces, err := parseInterfaces(out) + if err != nil { + return nil, huma.Error500InternalServerError("parse ip addr failed", err) + } + res := &ListInterfacesOutput{} + res.Body.Interfaces = ifaces + return res, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-list-routes", + Method: "GET", + Path: "/api/networking/routes", + Summary: "List the IPv4 route table", + Description: "Returns the kernel IPv4 route table via `ip -j route`.", + Tags: []string{tagNetworking}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListRoutesOutput, error) { + out, err := oscmd.RunContext(ctx, "ip", "-j", "route") + if err != nil { + return nil, huma.Error500InternalServerError("ip route failed", err) + } + routes, err := parseRoutes(out) + if err != nil { + return nil, huma.Error500InternalServerError("parse ip route failed", err) + } + res := &ListRoutesOutput{} + res.Body.Routes = routes + return res, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-get-dns", + Method: "GET", + Path: "/api/networking/dns", + Summary: "Get configured DNS servers", + Description: "Returns the nameservers listed in /etc/resolv.conf. DNS is set " + + "per-interface as part of the interface config (PUT /api/networking/interfaces/{name}), " + + "so there is no standalone DNS write endpoint.", + Tags: []string{tagNetworking}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*DNSOutput, error) { + data, err := os.ReadFile(resolvConf) + if err != nil { + return nil, huma.Error500InternalServerError("read resolv.conf failed", err) + } + res := &DNSOutput{} + res.Body.Servers = parseResolv(string(data)) + return res, nil + }) +} + +// --- parsers (pure, tested) -------------------------------------------------- + +// ipAddr mirrors the fields we use from one `ip -j addr` element. +type ipAddr struct { + Name string `json:"ifname"` + MAC string `json:"address"` + MTU int `json:"mtu"` + OperState string `json:"operstate"` + AddrInfo []struct { + Family string `json:"family"` // "inet" / "inet6" + Local string `json:"local"` + Prefix int `json:"prefixlen"` + } `json:"addr_info"` +} + +func parseInterfaces(jsonOut string) ([]Interface, error) { + var raw []ipAddr + if err := json.Unmarshal([]byte(jsonOut), &raw); err != nil { + return nil, err + } + ifaces := make([]Interface, 0, len(raw)) + for _, r := range raw { + iface := Interface{ + Name: r.Name, + State: strings.ToLower(r.OperState), + MAC: r.MAC, + MTU: r.MTU, + IPv4: []string{}, + IPv6: []string{}, + } + for _, a := range r.AddrInfo { + cidr := a.Local + "/" + strconv.Itoa(a.Prefix) + if a.Family == "inet6" { + iface.IPv6 = append(iface.IPv6, cidr) + } else { + iface.IPv4 = append(iface.IPv4, cidr) + } + } + ifaces = append(ifaces, iface) + } + return ifaces, nil +} + +// ipRoute mirrors the fields we use from one `ip -j route` element. +type ipRoute struct { + Dst string `json:"dst"` + Gateway string `json:"gateway"` + Dev string `json:"dev"` + PrefSrc string `json:"prefsrc"` + Metric int `json:"metric"` +} + +func parseRoutes(jsonOut string) ([]RouteEntry, error) { + var raw []ipRoute + if err := json.Unmarshal([]byte(jsonOut), &raw); err != nil { + return nil, err + } + routes := make([]RouteEntry, 0, len(raw)) + for _, r := range raw { + routes = append(routes, RouteEntry{ + Destination: r.Dst, + Gateway: r.Gateway, + Interface: r.Dev, + Source: r.PrefSrc, + Metric: r.Metric, + }) + } + return routes, nil +} + +// parseResolv extracts "nameserver X" entries, ignoring comments and other +// directives. +func parseResolv(text string) []string { + servers := []string{} + for line := range strings.SplitSeq(text, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") { + continue + } + if rest, ok := strings.CutPrefix(line, "nameserver"); ok { + if s := strings.TrimSpace(rest); s != "" { + servers = append(servers, s) + } + } + } + return servers +} diff --git a/internal/modules/networking/rollback.go b/internal/modules/networking/rollback.go new file mode 100644 index 0000000..3329664 --- /dev/null +++ b/internal/modules/networking/rollback.go @@ -0,0 +1,193 @@ +package networking + +import ( + "context" + "errors" + "fmt" + "log" + "time" +) + +const defaultRollbackSeconds = 60 + +// errAlreadyPending is returned when another change is awaiting confirmation. +// The write handlers map this to 409 Conflict. +var errAlreadyPending = errors.New("already pending") + +// errPending builds the 409 message. The lock is global across all interfaces +// (see pendingChange), so the message says so to avoid confusing a user who is +// touching a different interface than the one that holds the lock. +func errPending(iface string) error { + return fmt.Errorf("%w: a change to %s is awaiting confirmation. This is a global lock across all interfaces — confirm or roll that change back first", errAlreadyPending, iface) +} + +// startRollback snapshots the current state, applies the new config, and arms a +// timer that auto-reverts if not confirmed. Returns errAlreadyPending (409) if +// another change is in flight, or a wrapped error (500) if apply fails. +// +// Snapshot and Apply run WITHOUT the mutex held, so they don't block reads or +// the pending-status endpoint while shelling out to nmcli/networkctl/ifup. +func (m *Module) startRollback(ctx context.Context, iface string, cfg IfaceConfig) (int, error) { + // Fast pre-check so we don't snapshot/apply when something is already + // pending. armPending re-checks under the lock to close the race. + if err := m.checkNoPending(); err != nil { + return 0, err + } + + prior, err := m.be.Snapshot(ctx, iface) + if err != nil { + return 0, fmt.Errorf("snapshot %s: %w", iface, err) + } + if err := m.be.Apply(ctx, iface, cfg); err != nil { + // Apply is not atomic: nmcli `con modify` may succeed before `con up` + // fails, and networkd writes the .network file before `reconfigure` + // runs. A failed Apply can therefore leave a half-applied config that + // would otherwise have NO auto-revert (we bail before arming the timer). + // Best-effort restore the snapshot so we never leave that unprotected. + if rerr := m.be.Apply(ctx, iface, prior); rerr != nil { + log.Printf("networking: apply %s failed and restore also failed: %v", iface, rerr) + } + return 0, fmt.Errorf("apply %s: %w", iface, err) + } + + // The revert runs from the timer or an explicit rollback, possibly with no + // client attached, so it uses context.Background() rather than ctx. + return m.armPending(iface, func() error { return m.be.Apply(context.Background(), iface, prior) }, cfg.RollbackSeconds) +} + +// startLinkDown takes the interface down behind the same rollback safety net: if +// the change is not confirmed, the interface is brought back up. Taking a remote +// interface down is just as much a lock-yourself-out risk as a bad static config. +// +// Bringing a link UP needs no protection (it cannot lock you out), so link-up +// stays a direct, un-wrapped call in the handler. +func (m *Module) startLinkDown(ctx context.Context, iface string) (int, error) { + if err := m.checkNoPending(); err != nil { + return 0, err + } + if err := m.be.SetLinkDown(ctx, iface); err != nil { + return 0, fmt.Errorf("link down %s: %w", iface, err) + } + // Revert (bring the link back up) may run from the timer with no client. + return m.armPending(iface, func() error { return m.be.SetLinkUp(context.Background(), iface) }, 0) +} + +// checkNoPending reports a 409 error if a change is already pending. +func (m *Module) checkNoPending() error { + m.mu.Lock() + defer m.mu.Unlock() + if m.pending != nil { + return errPending(m.pending.Iface) + } + return nil +} + +// armPending installs the pending change and starts its auto-revert timer. The +// caller has already applied the change; revert is the closure that undoes it. +// It is invoked on timer expiry, explicit rollback, or if a concurrent change +// raced us between the pre-check and here (in which case we revert immediately +// and report the conflict). seconds <= 0 uses the default timeout. +func (m *Module) armPending(iface string, revert func() error, seconds int) (int, error) { + if seconds <= 0 { + seconds = defaultRollbackSeconds + } + dur := time.Duration(seconds) * time.Second + + m.mu.Lock() + defer m.mu.Unlock() + + if m.pending != nil { + // Lost the race — undo what we just applied and report conflict. + if err := revert(); err != nil { + log.Printf("networking: failed to undo raced change on %s: %v", iface, err) + } + return 0, errPending(m.pending.Iface) + } + + pc := &pendingChange{ + Iface: iface, + revert: revert, + Deadline: time.Now().Add(dur), + } + // The timer fires the auto-revert. It captures m and pc by closure so it can + // revert even if the server is otherwise idle — the whole point is protecting + // against being locked out of a remote box. + pc.Timer = time.AfterFunc(dur, func() { + m.mu.Lock() + defer m.mu.Unlock() + // Only revert if this exact change is still pending (it may have been + // confirmed or manually rolled back in the meantime). + if m.pending != pc { + return + } + log.Printf("networking: rollback timer expired for %s — reverting", iface) + if err := pc.revert(); err != nil { + log.Printf("networking: auto-rollback of %s failed: %v", iface, err) + } + m.pending = nil + }) + + m.pending = pc + return seconds, nil +} + +// confirm cancels the rollback timer and clears the pending change, making it +// permanent. Errors if there is no pending change or it's for another interface. +func (m *Module) confirm(iface string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.pending == nil { + return fmt.Errorf("no pending change to confirm") + } + if m.pending.Iface != iface { + return fmt.Errorf("pending change is for %s, not %s", m.pending.Iface, iface) + } + + m.pending.Timer.Stop() + m.pending = nil + return nil +} + +// rollbackNow immediately reverts the pending change and clears it. Errors if +// there is no pending change or it's for another interface. +func (m *Module) rollbackNow(iface string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.pending == nil { + return fmt.Errorf("no pending change to rollback") + } + if m.pending.Iface != iface { + return fmt.Errorf("pending change is for %s, not %s", m.pending.Iface, iface) + } + + m.pending.Timer.Stop() + err := m.pending.revert() + m.pending = nil + if err != nil { + return fmt.Errorf("rollback %s: %w", iface, err) + } + return nil +} + +// PendingInfo is the JSON body returned by the pending-change status endpoint. +type PendingInfo struct { + Iface string `json:"interface" example:"eth0" doc:"Interface with a pending change"` + SecondsRemaining int `json:"seconds_remaining" example:"45" doc:"Seconds until auto-rollback"` +} + +// pendingInfo returns the current pending change status, or nil if none. +func (m *Module) pendingInfo() *PendingInfo { + m.mu.Lock() + defer m.mu.Unlock() + + if m.pending == nil { + return nil + } + remaining := max(int(time.Until(m.pending.Deadline).Seconds()), 0) + return &PendingInfo{ + Iface: m.pending.Iface, + SecondsRemaining: remaining, + } +} diff --git a/internal/modules/networking/write.go b/internal/modules/networking/write.go new file mode 100644 index 0000000..41b643d --- /dev/null +++ b/internal/modules/networking/write.go @@ -0,0 +1,194 @@ +package networking + +import ( + "context" + "errors" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +// registerWrites adds all write endpoints for the networking module. Every +// handler checks m.be != nil first and returns 501 when no backend was detected. + +type ApplyInput struct { + Name string `path:"name" example:"eth0" doc:"Interface name"` + Body IfaceConfig `doc:"Desired interface configuration"` +} + +type ApplyOutput struct { + Body struct { + Status string `json:"status" example:"pending" doc:"Always \"pending\" — confirm to make permanent"` + Interface string `json:"interface" example:"eth0"` + Backend string `json:"backend" example:"nmcli" doc:"Network backend that applied the change"` + RollbackSeconds int `json:"rollback_seconds" example:"60" doc:"Seconds until auto-rollback unless confirmed"` + } +} + +type IfacePathInput struct { + Name string `path:"name" example:"eth0" doc:"Interface name"` +} + +type PendingOutput struct { + Body PendingInfo +} + +func registerWrites(api huma.API, m *Module) { + huma.Register(api, huma.Operation{ + OperationID: "networking-apply-config", + Method: "PUT", + Path: "/api/networking/interfaces/{name}", + Summary: "Apply interface configuration", + Description: "Replaces the interface's IPv4 configuration. The change is applied " + + "immediately but starts a rollback timer — if not confirmed within the timeout " + + "(default 60s), the prior configuration is automatically restored. This prevents " + + "lock-yourself-out mistakes on remote hosts.", + Tags: []string{tagNetworking}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *ApplyInput) (*ApplyOutput, error) { + if m.be == nil { + return nil, huma.Error501NotImplemented("", errNoBackend) + } + if err := validateIface(in.Name); err != nil { + return nil, err + } + if err := in.Body.validate(); err != nil { + return nil, err + } + + seconds, err := m.startRollback(ctx, in.Name, in.Body) + if err != nil { + if errors.Is(err, errAlreadyPending) { + return nil, huma.Error409Conflict(err.Error()) + } + return nil, huma.Error500InternalServerError("apply failed", err) + } + + out := &ApplyOutput{} + out.Body.Status = "pending" + out.Body.Interface = in.Name + out.Body.Backend = m.be.Name() + out.Body.RollbackSeconds = seconds + return out, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-confirm-change", + Method: "POST", + Path: "/api/networking/interfaces/{name}/confirm", + Summary: "Confirm a pending change", + Description: "Cancels the rollback timer, making the applied configuration permanent.", + Tags: []string{tagNetworking}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *IfacePathInput) (*oscmd.StatusOutput, error) { + if m.be == nil { + return nil, huma.Error501NotImplemented("", errNoBackend) + } + if err := validateIface(in.Name); err != nil { + return nil, err + } + if err := m.confirm(in.Name); err != nil { + return nil, huma.Error409Conflict(err.Error()) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-rollback-change", + Method: "POST", + Path: "/api/networking/interfaces/{name}/rollback", + Summary: "Immediately revert a pending change", + Description: "Reverts the interface to its prior configuration and clears the pending change.", + Tags: []string{tagNetworking}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *IfacePathInput) (*oscmd.StatusOutput, error) { + if m.be == nil { + return nil, huma.Error501NotImplemented("", errNoBackend) + } + if err := validateIface(in.Name); err != nil { + return nil, err + } + if err := m.rollbackNow(in.Name); err != nil { + return nil, huma.Error500InternalServerError("rollback failed", err) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-link-up", + Method: "POST", + Path: "/api/networking/interfaces/{name}/up", + Summary: "Bring an interface up", + Tags: []string{tagNetworking}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *IfacePathInput) (*oscmd.StatusOutput, error) { + if m.be == nil { + return nil, huma.Error501NotImplemented("", errNoBackend) + } + if err := validateIface(in.Name); err != nil { + return nil, err + } + if err := m.be.SetLinkUp(ctx, in.Name); err != nil { + return nil, huma.Error500InternalServerError("link up failed", err) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-link-down", + Method: "POST", + Path: "/api/networking/interfaces/{name}/down", + Summary: "Take an interface down", + Description: "Brings the interface down behind the rollback safety net: it is brought " + + "back up automatically if not confirmed within the timeout (default 60s). This " + + "prevents taking down the link you're managing the host over and losing access.", + Tags: []string{tagNetworking}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *IfacePathInput) (*ApplyOutput, error) { + if m.be == nil { + return nil, huma.Error501NotImplemented("", errNoBackend) + } + if err := validateIface(in.Name); err != nil { + return nil, err + } + seconds, err := m.startLinkDown(ctx, in.Name) + if err != nil { + if errors.Is(err, errAlreadyPending) { + return nil, huma.Error409Conflict(err.Error()) + } + return nil, huma.Error500InternalServerError("link down failed", err) + } + out := &ApplyOutput{} + out.Body.Status = "pending" + out.Body.Interface = in.Name + out.Body.Backend = m.be.Name() + out.Body.RollbackSeconds = seconds + return out, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "networking-get-pending", + Method: "GET", + Path: "/api/networking/pending", + Summary: "Get pending change status", + Description: "Returns the currently pending change (interface name and seconds until " + + "auto-rollback), or 404 if there is no pending change.", + Tags: []string{tagNetworking}, + Metadata: op("read"), + Errors: []int{401, 403, 404, 500}, + }, func(ctx context.Context, _ *struct{}) (*PendingOutput, error) { + info := m.pendingInfo() + if info == nil { + return nil, huma.Error404NotFound("no pending change") + } + out := &PendingOutput{} + out.Body = *info + return out, nil + }) +} diff --git a/internal/modules/packages/module.go b/internal/modules/packages/module.go new file mode 100644 index 0000000..7456ef3 --- /dev/null +++ b/internal/modules/packages/module.go @@ -0,0 +1,33 @@ +package packages + +import ( + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +const ModuleID = "packages" + +type Module struct { + pm manager // detected package manager; zero value means none found +} + +// New detects the host's package manager once at startup. +func New() *Module { return &Module{pm: detect()} } + +func (m *Module) ID() string { return ModuleID } +func (m *Module) Name() string { return "Packages" } + +// Permissions: read to list installed/available; write to install, remove, and +// upgrade. +func (m *Module) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Read, rbac.Write} +} + +func (m *Module) Register(api huma.API) { + registerPackages(api, m.pm) +} + +func op(permission string) map[string]any { + return map[string]any{"module": ModuleID, "permission": permission} +} diff --git a/internal/modules/packages/packages.go b/internal/modules/packages/packages.go new file mode 100644 index 0000000..632ad2b --- /dev/null +++ b/internal/modules/packages/packages.go @@ -0,0 +1,370 @@ +package packages + +import ( + "context" + "os/exec" + "regexp" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/sse" +) + +const tagPackages = "Packages" + +var ( + readErrors = []int{401, 403, 500} + writeErrors = []int{400, 401, 403, 500} +) + +// pkgNameRe matches a plain package name (no version specifiers, no shell +// metacharacters). Combined with a leading-dash reject and "--" before the name +// on every command line, it keeps user input from being read as a flag. +var pkgNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9@._+-]*$`) + +// manager identifies the host package manager. Its name drives a switch in each +// operation rather than an interface - three concrete tools, no polymorphism. +type manager struct{ name string } + +// detect picks the package manager. dnf/pacman are checked before apt because a +// few rpm-based boxes also ship an apt-get shim; the native tool should win. +func detect() manager { + for _, m := range []struct{ name, bin string }{ + {"dnf", "dnf"}, + {"pacman", "pacman"}, + {"apt", "apt-get"}, + } { + if _, err := exec.LookPath(m.bin); err == nil { + return manager{name: m.name} + } + } + return manager{} +} + +type Package struct { + Name string `json:"name" example:"openssh-server" doc:"Package name"` + Version string `json:"version" example:"9.6p1" doc:"Installed version, or the available version for updates"` +} + +type ListOutput struct { + Body struct { + Manager string `json:"manager" example:"dnf" doc:"Detected package manager"` + Packages []Package `json:"packages"` + } +} + +type InstallInput struct { + Body struct { + Name string `json:"name" example:"htop" doc:"Package to install"` + } +} + +type RemoveInput struct { + Name string `path:"name" example:"htop" doc:"Package to remove"` +} + +// SSE event types for streaming package operations. +type PkgOutputEvent struct { + Line string `json:"line" doc:"One line of the package manager's terminal output"` +} + +type PkgErrorEvent struct { + Message string `json:"message"` +} + +type PkgDoneEvent struct { + Success bool `json:"success" doc:"True if the package manager exited 0"` + Error string `json:"error,omitempty" doc:"Exit error when it failed"` +} + +// pkgEvents maps SSE event names to their payload types for the streaming +// install/remove/upgrade operations. +var pkgEvents = map[string]any{ + "output": PkgOutputEvent{}, + "done": PkgDoneEvent{}, + "error": PkgErrorEvent{}, +} + +func registerPackages(api huma.API, pm manager) { + huma.Register(api, huma.Operation{ + OperationID: "packages-list-installed", + Method: "GET", + Path: "/api/packages", + Summary: "List installed packages", + Tags: []string{tagPackages}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListOutput, error) { + return listInstalled(pm) + }) + + huma.Register(api, huma.Operation{ + OperationID: "packages-list-updates", + Method: "GET", + Path: "/api/packages/updates", + Summary: "List available updates", + Tags: []string{tagPackages}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListOutput, error) { + return listUpdates(pm) + }) + + // Install/remove/upgrade stream the package manager's terminal output live + // via SSE (one `output` event per line, a final `done` with exit status), + // rather than blocking the request until a long operation finishes. + sse.Register(api, huma.Operation{ + OperationID: "packages-install", + Method: "POST", + Path: "/api/packages", + Summary: "Install a package (streamed)", + Description: "Installs a package, streaming the package manager's output as " + + "`output` events and ending with a `done` event carrying the exit status.", + Tags: []string{tagPackages}, + Metadata: op("write"), + }, pkgEvents, func(ctx context.Context, in *InstallInput, send sse.Sender) { + if validateName(in.Body.Name) != nil { + send.Data(PkgErrorEvent{Message: "invalid package name: " + in.Body.Name}) + return + } + bin, args := pm.installArgs(in.Body.Name) + streamOp(ctx, send, bin, args) + }) + + sse.Register(api, huma.Operation{ + OperationID: "packages-remove", + Method: "DELETE", + Path: "/api/packages/{name}", + Summary: "Remove a package (streamed)", + Tags: []string{tagPackages}, + Metadata: op("write"), + }, pkgEvents, func(ctx context.Context, in *RemoveInput, send sse.Sender) { + if validateName(in.Name) != nil { + send.Data(PkgErrorEvent{Message: "invalid package name: " + in.Name}) + return + } + bin, args := pm.removeArgs(in.Name) + streamOp(ctx, send, bin, args) + }) + + sse.Register(api, huma.Operation{ + OperationID: "packages-upgrade", + Method: "POST", + Path: "/api/packages/upgrade", + Summary: "Upgrade all packages (streamed)", + Description: "Upgrades every package to its latest version, streaming the " + + "package manager's output live.", + Tags: []string{tagPackages}, + Metadata: op("write"), + }, pkgEvents, func(ctx context.Context, _ *struct{}, send sse.Sender) { + bin, args := pm.upgradeArgs() + streamOp(ctx, send, bin, args) + }) +} + +// streamOp runs a package write and streams its combined output to the client. +func streamOp(ctx context.Context, send sse.Sender, bin string, args []string) { + if bin == "" { + send.Data(PkgErrorEvent{Message: "no supported package manager found"}) + return + } + // DEBIAN_FRONTEND keeps apt from blocking on an interactive prompt. + lines, errc, err := oscmd.RunStreamCombined(ctx, []string{"DEBIAN_FRONTEND=noninteractive"}, bin, args...) + if err != nil { + send.Data(PkgErrorEvent{Message: err.Error()}) + return + } + for line := range lines { + if send.Data(PkgOutputEvent{Line: line}) != nil { + return // client gone; ctx cancel kills the process + } + } + done := PkgDoneEvent{Success: true} + if werr := <-errc; werr != nil { + done.Success = false + done.Error = werr.Error() + } + send.Data(done) +} + +func validateName(name string) error { + if !pkgNameRe.MatchString(name) { + return huma.Error400BadRequest("invalid package name: " + name) + } + return nil +} + +// --- reads ------------------------------------------------------------------- + +func listInstalled(pm manager) (*ListOutput, error) { + var out string + var err error + switch pm.name { + case "dnf": + out, err = oscmd.Run("rpm", "-qa", "--qf", "%{NAME}\t%{VERSION}-%{RELEASE}\n") + case "apt": + out, err = oscmd.Run("dpkg-query", "-W", "-f=${Package}\t${Version}\n") + case "pacman": + out, err = oscmd.Run("pacman", "-Q") + default: + return nil, huma.Error500InternalServerError("no supported package manager found") + } + if err != nil { + return nil, huma.Error500InternalServerError("listing installed packages failed", err) + } + pkgs := parseTabbed(out) + if pm.name == "pacman" { + pkgs = parseSpaced(out) + } + return result(pm, pkgs), nil +} + +func listUpdates(pm manager) (*ListOutput, error) { + switch pm.name { + case "dnf": + // check-update exits 100 when updates exist, 0 when none - both fine. + out, code, err := oscmd.RunStatus("dnf", "-q", "check-update") + if err != nil || (code != 0 && code != 100) { + return nil, huma.Error500InternalServerError("dnf check-update failed", err) + } + return result(pm, parseDnf(out)), nil + case "apt": + out, err := oscmd.Run("apt", "list", "--upgradable", "-qq") + if err != nil { + return nil, huma.Error500InternalServerError("apt list failed", err) + } + return result(pm, parseApt(out)), nil + case "pacman": + // -Qu exits 1 when there is nothing to upgrade - not an error. + out, code, err := oscmd.RunStatus("pacman", "-Qu") + if err != nil || (code != 0 && code != 1) { + return nil, huma.Error500InternalServerError("pacman -Qu failed", err) + } + return result(pm, parsePacmanUpdates(out)), nil + default: + return nil, huma.Error500InternalServerError("no supported package manager found") + } +} + +func result(pm manager, pkgs []Package) *ListOutput { + out := &ListOutput{} + out.Body.Manager = pm.name + out.Body.Packages = pkgs + return out +} + +// --- writes ------------------------------------------------------------------ + +func (m manager) installArgs(name string) (string, []string) { + switch m.name { + case "dnf": + return "dnf", []string{"install", "-y", "--", name} + case "apt": + return "apt-get", []string{"install", "-y", "--", name} + case "pacman": + return "pacman", []string{"-S", "--noconfirm", "--", name} + } + return "", nil +} + +func (m manager) removeArgs(name string) (string, []string) { + switch m.name { + case "dnf": + return "dnf", []string{"remove", "-y", "--", name} + case "apt": + return "apt-get", []string{"remove", "-y", "--", name} + case "pacman": + return "pacman", []string{"-R", "--noconfirm", "--", name} + } + return "", nil +} + +func (m manager) upgradeArgs() (string, []string) { + switch m.name { + case "dnf": + return "dnf", []string{"upgrade", "-y"} + case "apt": + return "apt-get", []string{"upgrade", "-y"} + case "pacman": + return "pacman", []string{"-Su", "--noconfirm"} + } + return "", nil +} + +// --- parsers (pure, tested) -------------------------------------------------- + +// parseTabbed reads "name\tversion" lines (dpkg-query / rpm output). +func parseTabbed(out string) []Package { + pkgs := []Package{} + for line := range strings.SplitSeq(out, "\n") { + if name, ver, ok := strings.Cut(line, "\t"); ok && name != "" { + pkgs = append(pkgs, Package{Name: name, Version: ver}) + } + } + return pkgs +} + +// parseSpaced reads "name version" lines (pacman -Q). +func parseSpaced(out string) []Package { + pkgs := []Package{} + for line := range strings.SplitSeq(out, "\n") { + if f := strings.Fields(line); len(f) >= 2 { + pkgs = append(pkgs, Package{Name: f[0], Version: f[1]}) + } + } + return pkgs +} + +// parseDnf reads `dnf check-update` lines ("name.arch version repo"), skipping +// the section headers and blank lines that dnf5 emits. +func parseDnf(out string) []Package { + pkgs := []Package{} + for line := range strings.SplitSeq(out, "\n") { + f := strings.Fields(line) + // Real rows have 3 columns and an arch suffix on the name; headers like + // "Upgrades" or "Obsoleting Packages" don't. + if len(f) != 3 || !strings.Contains(f[0], ".") { + continue + } + pkgs = append(pkgs, Package{Name: stripArch(f[0]), Version: f[1]}) + } + return pkgs +} + +// parseApt reads `apt list --upgradable` lines ("name/repo version arch [...]"). +func parseApt(out string) []Package { + pkgs := []Package{} + for line := range strings.SplitSeq(out, "\n") { + f := strings.Fields(line) + if len(f) < 2 || !strings.Contains(f[0], "/") { + continue + } + name, _, _ := strings.Cut(f[0], "/") + pkgs = append(pkgs, Package{Name: name, Version: f[1]}) + } + return pkgs +} + +// parsePacmanUpdates reads `pacman -Qu` lines ("name oldver -> newver"). +func parsePacmanUpdates(out string) []Package { + pkgs := []Package{} + for line := range strings.SplitSeq(out, "\n") { + f := strings.Fields(line) + if len(f) < 4 || f[2] != "->" { + continue + } + pkgs = append(pkgs, Package{Name: f[0], Version: f[3]}) + } + return pkgs +} + +// stripArch removes the trailing ".arch" from an rpm "name.arch" token. The +// arch is always the final dotted segment, so cut at the last dot. +func stripArch(s string) string { + if i := strings.LastIndex(s, "."); i > 0 { + return s[:i] + } + return s +} diff --git a/internal/modules/packages/packages_handler_test.go b/internal/modules/packages/packages_handler_test.go new file mode 100644 index 0000000..1741eef --- /dev/null +++ b/internal/modules/packages/packages_handler_test.go @@ -0,0 +1,125 @@ +package packages + +import ( + "encoding/json" + "net/http" + "os" + "reflect" + "strings" + "testing" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestPackagesHandlers(t *testing.T) { + managers := []string{"dnf", "apt", "pacman"} + + for _, mgrName := range managers { + t.Run(mgrName, func(t *testing.T) { + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + + m := &Module{pm: manager{name: mgrName}} + m.Register(api) + + oscmd.SetMock("rpm", func(args []string) oscmd.MockCommand { + return oscmd.MockCommand{Stdout: "htop\t3.3.0-1\n", ExitCode: 0} + }) + oscmd.SetMock("dpkg-query", func(args []string) oscmd.MockCommand { + return oscmd.MockCommand{Stdout: "htop\t3.3.0-1\n", ExitCode: 0} + }) + oscmd.SetMock("pacman", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"-Q"}) { + return oscmd.MockCommand{Stdout: "htop 3.3.0-1\n", ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"-Qu"}) { + return oscmd.MockCommand{Stdout: "linux 6.9.1-1 -> 6.9.2-1\n", ExitCode: 0} + } + return oscmd.MockCommand{Lines: []string{"pacman output"}, ExitCode: 0} + }) + oscmd.SetMock("dnf", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"-q", "check-update"}) { + return oscmd.MockCommand{Stdout: "Upgrades\ncode.x86_64 1.125.1-1 code\n", ExitCode: 100} + } + return oscmd.MockCommand{Lines: []string{"dnf output"}, ExitCode: 0} + }) + oscmd.SetMock("apt", func(args []string) oscmd.MockCommand { + return oscmd.MockCommand{Stdout: "Listing...\nvim/jammy 2:8.2.3995 amd64 [upgradable]\n", ExitCode: 0} + }) + oscmd.SetMock("apt-get", func(args []string) oscmd.MockCommand { + return oscmd.MockCommand{Lines: []string{"apt-get output"}, ExitCode: 0} + }) + defer oscmd.ClearMocks() + + // Test list installed + resp := api.Get("/api/packages") + if resp.Code != http.StatusOK { + t.Errorf("list installed: got %d, want %d", resp.Code, http.StatusOK) + } + var listRes ListOutput + if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil { + t.Fatal(err) + } + if len(listRes.Body.Packages) != 1 || listRes.Body.Packages[0].Name != "htop" { + t.Errorf("unexpected installed packages: %+v", listRes.Body) + } + + // Test list updates + resp = api.Get("/api/packages/updates") + if resp.Code != http.StatusOK { + t.Errorf("list updates: got %d, want %d", resp.Code, http.StatusOK) + } + if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil { + t.Fatal(err) + } + if len(listRes.Body.Packages) != 1 { + t.Errorf("unexpected updates: %+v", listRes.Body) + } + + // Test install (SSE) + resp = api.Post("/api/packages", struct { + Name string `json:"name"` + }{ + Name: "htop", + }) + if resp.Code != http.StatusOK { + t.Errorf("install: got %d, want %d", resp.Code, http.StatusOK) + } + bodyStr := resp.Body.String() + if !strings.Contains(bodyStr, "done") { + t.Errorf("install stream output missing done: %q", bodyStr) + } + + // Test remove (SSE) + resp = api.Delete("/api/packages/htop") + if resp.Code != http.StatusOK { + t.Errorf("remove: got %d, want %d", resp.Code, http.StatusOK) + } + bodyStr = resp.Body.String() + if !strings.Contains(bodyStr, "done") { + t.Errorf("remove stream output missing done: %q", bodyStr) + } + + // Test upgrade (SSE) + resp = api.Post("/api/packages/upgrade", struct{}{}) + if resp.Code != http.StatusOK { + t.Errorf("upgrade: got %d, want %d", resp.Code, http.StatusOK) + } + bodyStr = resp.Body.String() + if !strings.Contains(bodyStr, "done") { + t.Errorf("upgrade stream output missing done: %q", bodyStr) + } + }) + } +} diff --git a/internal/modules/packages/packages_test.go b/internal/modules/packages/packages_test.go new file mode 100644 index 0000000..d1eee9c --- /dev/null +++ b/internal/modules/packages/packages_test.go @@ -0,0 +1,80 @@ +package packages + +import ( + "reflect" + "testing" +) + +func TestParseTabbed(t *testing.T) { + out := "zeromq\t4.3.5-22.fc43\nspice-server\t0.16.0-2.fc43\n\nbad-no-tab\n" + got := parseTabbed(out) + want := []Package{{"zeromq", "4.3.5-22.fc43"}, {"spice-server", "0.16.0-2.fc43"}} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +func TestParseSpaced(t *testing.T) { + got := parseSpaced("linux 6.9.1\nhtop 3.3.0\n") + want := []Package{{"linux", "6.9.1"}, {"htop", "3.3.0"}} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +func TestParseDnf(t *testing.T) { + // dnf5 emits a section header ("Upgrades") that must be skipped. + out := "Upgrades\n" + + "code.x86_64 1.125.1-1781859648.el8 code\n" + + "containerd.io.x86_64 2.2.5-1.fc44 docker-ce-stable\n" + + "\nObsoleting Packages\n" + got := parseDnf(out) + want := []Package{{"code", "1.125.1-1781859648.el8"}, {"containerd.io", "2.2.5-1.fc44"}} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +func TestParseApt(t *testing.T) { + out := "Listing...\n" + + "vim/jammy-updates 2:8.2.3995-1ubuntu2.15 amd64 [upgradable from: 2:8.2.3995-1ubuntu2.1]\n" + got := parseApt(out) + want := []Package{{"vim", "2:8.2.3995-1ubuntu2.15"}} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +func TestParsePacmanUpdates(t *testing.T) { + got := parsePacmanUpdates("linux 6.9.1-1 -> 6.9.2-1\nfoo 1.0 1.0\n") + want := []Package{{"linux", "6.9.2-1"}} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +func TestStripArch(t *testing.T) { + cases := map[string]string{ + "code.x86_64": "code", + "python3.11.noarch": "python3.11", // arch is only the final segment + "noarchhere": "noarchhere", + } + for in, want := range cases { + if got := stripArch(in); got != want { + t.Errorf("stripArch(%q) = %q, want %q", in, got, want) + } + } +} + +func TestValidateName(t *testing.T) { + for _, n := range []string{"htop", "openssh-server", "lib32-glibc", "g++", "python3.11"} { + if err := validateName(n); err != nil { + t.Errorf("validateName(%q) = %v, want nil", n, err) + } + } + for _, n := range []string{"", "-rf", "foo;rm", "foo bar", "pkg=1.0", "a/b"} { + if err := validateName(n); err == nil { + t.Errorf("validateName(%q) = nil, want error", n) + } + } +} diff --git a/internal/modules/services/logs.go b/internal/modules/services/logs.go new file mode 100644 index 0000000..a71b15f --- /dev/null +++ b/internal/modules/services/logs.go @@ -0,0 +1,237 @@ +package services + +import ( + "context" + "encoding/json" + "slices" + "strconv" + "strings" + "time" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/sse" +) + +const ( + defaultLogLines = 100 + maxLogLines = 10000 +) + +// LogEntry is one log record. For the journal source it is distilled from +// journalctl's JSON; for the file source only Message is set (the raw line, +// which usually carries its own embedded timestamp). +type LogEntry struct { + Time string `json:"time" example:"2026-06-20T08:15:04Z" doc:"Record timestamp (RFC3339, UTC); empty for file lines"` + Priority int `json:"priority" example:"6" doc:"syslog priority 0 (emerg) – 7 (debug); 6 for file lines"` + Message string `json:"message" example:"Started OpenSSH server daemon."` +} + +// ErrorEvent is an SSE event carrying a stream-level error (e.g. bad unit). +type ErrorEvent struct { + Message string `json:"message"` +} + +type LogsInput struct { + Unit string `path:"unit" example:"docker.service" doc:"Unit name as listed by GET /api/services; the trailing .service is optional"` + Source string `query:"source" enum:"journal,file" default:"journal" doc:"Where to read logs from"` + Path string `query:"path" example:"/var/log/nginx/error.log" doc:"Log file (file source only); must be allowlisted for this unit in config"` + Lines int `query:"lines" default:"100" doc:"How many recent records to return (max 10000)"` + Since string `query:"since" example:"-1h" doc:"journalctl time filter (journal source only)"` + Priority int `query:"priority" default:"7" minimum:"0" maximum:"7" doc:"Max syslog priority to include: 0 emerg .. 7 debug (journal source only). 7 = all."` +} + +type LogsOutput struct { + Body struct { + Entries []LogEntry `json:"entries" doc:"Log records, oldest first"` + } +} + +type LogStreamInput struct { + Unit string `path:"unit" example:"docker.service" doc:"Unit name as listed by GET /api/services; the trailing .service is optional"` + Source string `query:"source" enum:"journal,file" default:"journal" doc:"Where to stream logs from"` + Path string `query:"path" example:"/var/log/nginx/error.log" doc:"Log file (file source only); must be allowlisted for this unit in config"` + Since string `query:"since" example:"-1h" doc:"Backfill window (journal source only)"` + Priority int `query:"priority" default:"7" minimum:"0" maximum:"7" doc:"Max syslog priority to include: 0 emerg .. 7 debug (journal source only). 7 = all."` +} + +func registerLogs(api huma.API, logFiles map[string][]string) { + huma.Register(api, huma.Operation{ + OperationID: "services-logs", + Method: "GET", + Path: "/api/services/{unit}/logs", + Summary: "Get recent log records for a service", + Description: "Returns a snapshot of the unit's logs from the journal " + + "(default) or an allowlisted file (source=file&path=). Use /logs/stream " + + "to follow new records live.", + Tags: []string{tagServices}, + Metadata: op("read"), + // No 404: the journal is historical, so logs are returned even for units + // that aren't currently loaded; an unknown unit just yields an empty list. + Errors: []int{400, 401, 403, 500}, + }, func(ctx context.Context, in *LogsInput) (*LogsOutput, error) { + if err := validateUnit(in.Unit); err != nil { + return nil, err + } + + var lines []string + var err error + if in.Source == "file" { + path, perr := resolveLogPath(logFiles, in.Unit, in.Path) + if perr != nil { + return nil, perr + } + lines, err = oscmd.RunLines("tail", "-n", strconv.Itoa(clampLines(in.Lines)), "--", path) + if err != nil { + return nil, huma.Error500InternalServerError("tail failed", err) + } + return fileOutput(lines), nil + } + + args := []string{"-u", journalUnit(in.Unit), "--no-pager", "-o", "json", "-p", strconv.Itoa(in.Priority), "-n", strconv.Itoa(clampLines(in.Lines))} + if in.Since != "" { + args = append(args, "--since", in.Since) + } + lines, err = oscmd.RunLines("journalctl", args...) + if err != nil { + return nil, huma.Error500InternalServerError("journalctl failed", err) + } + out := &LogsOutput{} + out.Body.Entries = []LogEntry{} + for _, l := range lines { + if e, ok := parseJournalLine([]byte(l)); ok { + out.Body.Entries = append(out.Body.Entries, e) + } + } + return out, nil + }) + + // Streaming via huma's sse package keeps the route inside huma, so the RBAC + // middleware still enforces op("read") - a raw mux handler would bypass it. + sse.Register(api, huma.Operation{ + OperationID: "services-logs-stream", + Method: "GET", + Path: "/api/services/{unit}/logs/stream", + Summary: "Stream a service's logs (Server-Sent Events)", + Description: "Follows the unit's journal (journalctl -f) or an allowlisted " + + "file (source=file&path=, via tail -F) and emits a `log` event per " + + "record. Stops when the client disconnects.", + Tags: []string{tagServices}, + Metadata: op("read"), + }, map[string]any{ + "log": LogEntry{}, + "error": ErrorEvent{}, + }, func(ctx context.Context, in *LogStreamInput, send sse.Sender) { + if err := validateUnit(in.Unit); err != nil { + send.Data(ErrorEvent{Message: "invalid unit name"}) + return + } + + var cmd string + var args []string + if in.Source == "file" { + path, perr := resolveLogPath(logFiles, in.Unit, in.Path) + if perr != nil { + send.Data(ErrorEvent{Message: perr.Error()}) + return + } + cmd, args = "tail", []string{"-n", strconv.Itoa(defaultLogLines), "-F", "--", path} + } else { + cmd = "journalctl" + args = []string{"-u", journalUnit(in.Unit), "--no-pager", "-o", "json", "-p", strconv.Itoa(in.Priority), "-f"} + if in.Since != "" { + args = append(args, "--since", in.Since) + } + } + + lines, err := oscmd.RunStream(ctx, cmd, args...) + if err != nil { + send.Data(ErrorEvent{Message: cmd + " failed: " + err.Error()}) + return + } + for l := range lines { + e, ok := LogEntry{Priority: 6, Message: l}, true + if in.Source != "file" { + e, ok = parseJournalLine([]byte(l)) + } + if ok { + if send.Data(e) != nil { + return // client gone; ctx cancel will kill the command + } + } + } + }) +} + +// resolveLogPath validates that path is allowlisted for unit. The caller never +// gets to point exec at an arbitrary file - only paths an admin listed under +// log_files for this unit are accepted. +func resolveLogPath(logFiles map[string][]string, unit, path string) (string, error) { + if path == "" { + return "", huma.Error400BadRequest("source=file requires a path") + } + // Match the allowlist key suffix-insensitively (nginx == nginx.service), so + // it behaves like the journal source regardless of which form the caller and + // the config author each used. + want := journalUnit(unit) + for key, paths := range logFiles { + if journalUnit(key) == want && slices.Contains(paths, path) { + return path, nil + } + } + return "", huma.Error403Forbidden("log file not allowlisted for unit " + unit + ": " + path) +} + +func fileOutput(lines []string) *LogsOutput { + out := &LogsOutput{} + out.Body.Entries = make([]LogEntry, 0, len(lines)) + for _, l := range lines { + out.Body.Entries = append(out.Body.Entries, LogEntry{Priority: 6, Message: l}) + } + return out +} + +// journalUnit normalizes a unit name for `journalctl -u`. journalctl treats a +// bare name as a .service, and on some setups only the bare form matches the +// recorded _SYSTEMD_UNIT, so we always strip the suffix. This is the services +// module, so .service is the only suffix we expect. +func journalUnit(unit string) string { + return strings.TrimSuffix(unit, ".service") +} + +func clampLines(n int) int { + switch { + case n <= 0: + return defaultLogLines + case n > maxLogLines: + return maxLogLines + default: + return n + } +} + +// parseJournalLine distills one journalctl `-o json` record. Returns false for +// unparseable lines. Binary MESSAGE fields (encoded as a byte array rather than +// a string) yield an empty message rather than an error. +func parseJournalLine(line []byte) (LogEntry, bool) { + var raw struct { + Message any `json:"MESSAGE"` + Priority string `json:"PRIORITY"` + TS string `json:"__REALTIME_TIMESTAMP"` + } + if err := json.Unmarshal(line, &raw); err != nil { + return LogEntry{}, false + } + // Records without a PRIORITY (it's often absent) default to info (6), not + // the zero value 0 which is emerg - that would fake critical alerts. + e := LogEntry{Priority: 6} + e.Message, _ = raw.Message.(string) + if p, err := strconv.Atoi(raw.Priority); err == nil { + e.Priority = p + } + if us, err := strconv.ParseInt(raw.TS, 10, 64); err == nil { + e.Time = time.UnixMicro(us).UTC().Format(time.RFC3339) + } + return e, true +} diff --git a/internal/modules/services/logs_test.go b/internal/modules/services/logs_test.go new file mode 100644 index 0000000..42d1223 --- /dev/null +++ b/internal/modules/services/logs_test.go @@ -0,0 +1,96 @@ +package services + +import "testing" + +func TestParseJournalLine(t *testing.T) { + line := []byte(`{"__REALTIME_TIMESTAMP":"1750406104000000","PRIORITY":"6","MESSAGE":"Started OpenSSH server daemon.","_PID":"123"}`) + e, ok := parseJournalLine(line) + if !ok { + t.Fatal("expected parse to succeed") + } + if e.Message != "Started OpenSSH server daemon." { + t.Errorf("message = %q", e.Message) + } + if e.Priority != 6 { + t.Errorf("priority = %d", e.Priority) + } + if e.Time != "2025-06-20T07:55:04Z" { + t.Errorf("time = %q", e.Time) + } +} + +func TestParseJournalLineBinaryMessage(t *testing.T) { + // Binary MESSAGE is encoded as a byte array; we yield an empty message, not an error. + line := []byte(`{"__REALTIME_TIMESTAMP":"1750406104000000","PRIORITY":"3","MESSAGE":[104,105]}`) + e, ok := parseJournalLine(line) + if !ok || e.Message != "" || e.Priority != 3 { + t.Errorf("got ok=%v entry=%+v", ok, e) + } +} + +func TestParseJournalLineMissingPriority(t *testing.T) { + // PRIORITY is often absent; it must default to info (6), not emerg (0). + line := []byte(`{"__REALTIME_TIMESTAMP":"1750406104000000","MESSAGE":"hi"}`) + e, ok := parseJournalLine(line) + if !ok || e.Priority != 6 { + t.Errorf("got ok=%v priority=%d, want priority 6", ok, e.Priority) + } +} + +func TestParseJournalLineGarbage(t *testing.T) { + if _, ok := parseJournalLine([]byte("not json")); ok { + t.Error("garbage line should not parse") + } +} + +func TestResolveLogPath(t *testing.T) { + allow := map[string][]string{ + "nginx.service": {"/var/log/nginx/access.log", "/var/log/nginx/error.log"}, + } + + // Allowlisted path resolves whether the caller uses the bare or .service + // form, regardless of which form the config key used. + for _, unit := range []string{"nginx.service", "nginx"} { + if p, err := resolveLogPath(allow, unit, "/var/log/nginx/error.log"); err != nil || p != "/var/log/nginx/error.log" { + t.Errorf("allowlisted path for %q: got %q, %v", unit, p, err) + } + } + + // Everything else is rejected: empty path, non-listed path (traversal), + // listed path but wrong unit, and unit with no allowlist at all. + bad := []struct{ unit, path string }{ + {"nginx.service", ""}, + {"nginx.service", "/etc/shadow"}, + {"nginx.service", "/var/log/nginx/access.log/../../../etc/shadow"}, + {"sshd.service", "/var/log/nginx/error.log"}, + {"unknown.service", "/var/log/nginx/error.log"}, + } + for _, b := range bad { + if _, err := resolveLogPath(allow, b.unit, b.path); err == nil { + t.Errorf("resolveLogPath(%q, %q) = nil error, want rejection", b.unit, b.path) + } + } +} + +func TestJournalUnit(t *testing.T) { + cases := map[string]string{ + "docker.service": "docker", + "docker": "docker", + "sshd.service": "sshd", + "foo.socket": "foo.socket", // only .service is stripped + } + for in, want := range cases { + if got := journalUnit(in); got != want { + t.Errorf("journalUnit(%q) = %q, want %q", in, got, want) + } + } +} + +func TestClampLines(t *testing.T) { + cases := map[int]int{0: defaultLogLines, -5: defaultLogLines, 50: 50, 999999: maxLogLines} + for in, want := range cases { + if got := clampLines(in); got != want { + t.Errorf("clampLines(%d) = %d, want %d", in, got, want) + } + } +} diff --git a/internal/modules/services/module.go b/internal/modules/services/module.go new file mode 100644 index 0000000..6c5fcd4 --- /dev/null +++ b/internal/modules/services/module.go @@ -0,0 +1,38 @@ +package services + +import ( + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +const ModuleID = "services" + +type Module struct { + // logFiles is the per-unit allowlist of readable log files (from config), + // consulted by the file log source. + logFiles map[string][]string +} + +func New(logFiles map[string][]string) *Module { return &Module{logFiles: logFiles} } + +func (m *Module) ID() string { return ModuleID } +func (m *Module) Name() string { return "Services" } + +// Permissions: read to list and inspect units; write to control them +// (start/stop/restart/enable/disable). +func (m *Module) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Read, rbac.Write} +} + +func (m *Module) Register(api huma.API) { + registerServices(api) + registerLogs(api, m.logFiles) +} + +func op(permission string) map[string]any { + return map[string]any{ + "module": ModuleID, + "permission": permission, + } +} diff --git a/internal/modules/services/services.go b/internal/modules/services/services.go new file mode 100644 index 0000000..b151a97 --- /dev/null +++ b/internal/modules/services/services.go @@ -0,0 +1,180 @@ +package services + +import ( + "context" + "encoding/json" + "regexp" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +const tagServices = "Services" + +var ( + readErrors = []int{401, 403, 500} + writeErrors = []int{400, 401, 403, 404, 500} +) + +// unitNameRe matches valid systemd unit names. Combined with a leading-dash +// reject and the "--" separator on every systemctl call, it keeps user-supplied +// unit names from being interpreted as options. +var unitNameRe = regexp.MustCompile(`^[a-zA-Z0-9@._:-]+$`) + +// ServiceUnit mirrors one entry of `systemctl list-units --type=service -o json`. +type ServiceUnit struct { + Unit string `json:"unit" example:"sshd.service" doc:"Unit name"` + Load string `json:"load" example:"loaded" doc:"Load state"` + Active string `json:"active" example:"active" doc:"High-level active state"` + Sub string `json:"sub" example:"running" doc:"Low-level sub state"` + Description string `json:"description" example:"OpenSSH server daemon" doc:"Unit description"` +} + +type ListServicesOutput struct { + Body struct { + Services []ServiceUnit `json:"services" doc:"All service units, active and inactive"` + } +} + +// ServiceStatusBody is the detailed status of a single unit from `systemctl show`. +type ServiceStatusBody struct { + Unit string `json:"unit" example:"sshd.service"` + Description string `json:"description" example:"OpenSSH server daemon"` + LoadState string `json:"load_state" example:"loaded" doc:"loaded / not-found / masked"` + ActiveState string `json:"active_state" example:"active" doc:"active / inactive / failed"` + SubState string `json:"sub_state" example:"running"` + UnitFileState string `json:"unit_file_state" example:"enabled" doc:"enabled / disabled / static"` +} + +type GetServiceOutput struct{ Body ServiceStatusBody } + +// UnitPath is the shared path parameter for per-unit operations. +type UnitPath struct { + Unit string `path:"unit" example:"sshd.service" doc:"systemd unit name"` +} + +func registerServices(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "services-list", + Method: "GET", + Path: "/api/services", + Summary: "List service units", + Description: "Returns all service units (active and inactive) via " + + "`systemctl list-units --type=service --all`.", + Tags: []string{tagServices}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListServicesOutput, error) { + out, err := oscmd.Run("systemctl", "list-units", "--type=service", "--all", "-o", "json", "--no-pager") + if err != nil { + return nil, huma.Error500InternalServerError("systemctl list-units failed", err) + } + var units []ServiceUnit + if err := json.Unmarshal([]byte(out), &units); err != nil { + return nil, huma.Error500InternalServerError("parse systemctl json failed", err) + } + res := &ListServicesOutput{} + res.Body.Services = units + return res, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "services-get", + Method: "GET", + Path: "/api/services/{unit}", + Summary: "Get a service's status", + Description: "Returns load/active/sub/unit-file state for one unit via " + + "`systemctl show`. Returns 404 when the unit does not exist.", + Tags: []string{tagServices}, + Metadata: op("read"), + Errors: []int{400, 401, 403, 404, 500}, + }, func(ctx context.Context, in *UnitPath) (*GetServiceOutput, error) { + if err := validateUnit(in.Unit); err != nil { + return nil, err + } + m, err := showUnit(in.Unit) + if err != nil { + return nil, huma.Error500InternalServerError("systemctl show failed", err) + } + if m["LoadState"] == "not-found" { + return nil, huma.Error404NotFound("unit not found: " + in.Unit) + } + out := &GetServiceOutput{Body: ServiceStatusBody{ + Unit: m["Id"], + Description: m["Description"], + LoadState: m["LoadState"], + ActiveState: m["ActiveState"], + SubState: m["SubState"], + UnitFileState: m["UnitFileState"], + }} + return out, nil + }) + + controls := []struct{ action, summary, desc string }{ + {"start", "Start a service", "Starts the unit (`systemctl start`)."}, + {"stop", "Stop a service", "Stops the unit (`systemctl stop`)."}, + {"restart", "Restart a service", "Restarts the unit (`systemctl restart`)."}, + {"enable", "Enable a service at boot", "Enables the unit (`systemctl enable`)."}, + {"disable", "Disable a service at boot", "Disables the unit (`systemctl disable`)."}, + } + for _, c := range controls { + huma.Register(api, huma.Operation{ + OperationID: "services-" + c.action, + Method: "POST", + Path: "/api/services/{unit}/" + c.action, + Summary: c.summary, + Description: c.desc + " Returns 404 when the unit does not exist.", + Tags: []string{tagServices}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *UnitPath) (*oscmd.StatusOutput, error) { + if err := validateUnit(in.Unit); err != nil { + return nil, err + } + if err := ensureExists(in.Unit); err != nil { + return nil, err + } + if _, err := oscmd.Run("systemctl", c.action, "--", in.Unit); err != nil { + return nil, huma.Error500InternalServerError("systemctl "+c.action+" failed", err) + } + return oscmd.OK(), nil + }) + } +} + +// validateUnit guards against empty, flag-like, or malformed unit names. +func validateUnit(unit string) error { + if unit == "" || strings.HasPrefix(unit, "-") || !unitNameRe.MatchString(unit) { + return huma.Error400BadRequest("invalid unit name: " + unit) + } + return nil +} + +// showUnit returns selected properties of a unit as a key=value map. systemctl +// show exits 0 even for unknown units (LoadState=not-found), so callers must +// check LoadState to detect non-existence. +func showUnit(unit string) (map[string]string, error) { + lines, err := oscmd.RunLines("systemctl", "show", + "-p", "Id", "-p", "Description", "-p", "LoadState", + "-p", "ActiveState", "-p", "SubState", "-p", "UnitFileState", + "--", unit) + if err != nil { + return nil, err + } + return oscmd.ParseKV(lines), nil +} + +// ensureExists returns a 404 if the unit is unknown, mapping the systemctl +// show probe to an HTTP error for the control endpoints. +func ensureExists(unit string) error { + m, err := showUnit(unit) + if err != nil { + return huma.Error500InternalServerError("systemctl show failed", err) + } + if m["LoadState"] == "not-found" { + return huma.Error404NotFound("unit not found: " + unit) + } + return nil +} diff --git a/internal/modules/services/services_handler_test.go b/internal/modules/services/services_handler_test.go new file mode 100644 index 0000000..5c11159 --- /dev/null +++ b/internal/modules/services/services_handler_test.go @@ -0,0 +1,161 @@ +package services + +import ( + "encoding/json" + "net/http" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestServicesHandlers(t *testing.T) { + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + + // Set up allowlisted log files for testing the file source + logFiles := map[string][]string{ + "nginx.service": {filepath.Join(t.TempDir(), "nginx-error.log")}, + } + // Create the dummy log file + errLogPath := logFiles["nginx.service"][0] + if err := os.WriteFile(errLogPath, []byte("file log line 1\nfile log line 2\n"), 0644); err != nil { + t.Fatal(err) + } + + m := New(logFiles) + m.Register(api) + + // 1. Test GET /api/services (list services) + oscmd.SetMock("systemctl", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"list-units", "--type=service", "--all", "-o", "json", "--no-pager"}) { + units := []ServiceUnit{ + {Unit: "sshd.service", Load: "loaded", Active: "active", Sub: "running", Description: "OpenSSH"}, + } + data, _ := json.Marshal(units) + return oscmd.MockCommand{Stdout: string(data) + "\n", ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"show", "-p", "Id", "-p", "Description", "-p", "LoadState", "-p", "ActiveState", "-p", "SubState", "-p", "UnitFileState", "--", "sshd.service"}) { + showOut := "Id=sshd.service\nDescription=OpenSSH\nLoadState=loaded\nActiveState=active\nSubState=running\nUnitFileState=enabled\n" + return oscmd.MockCommand{Stdout: showOut, ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"start", "--", "sshd.service"}) { + return oscmd.MockCommand{ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + defer oscmd.ClearMocks() + + resp := api.Get("/api/services") + if resp.Code != http.StatusOK { + t.Errorf("list services: got %d, want %d", resp.Code, http.StatusOK) + } + var listRes ListServicesOutput + if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil { + t.Fatal(err) + } + if len(listRes.Body.Services) != 1 || listRes.Body.Services[0].Unit != "sshd.service" { + t.Errorf("list services output: %+v", listRes.Body) + } + + // 2. Test GET /api/services/{unit} (get service status) + resp = api.Get("/api/services/sshd.service") + if resp.Code != http.StatusOK { + t.Errorf("get service status: got %d, want %d", resp.Code, http.StatusOK) + } + var getRes GetServiceOutput + if err := json.Unmarshal(resp.Body.Bytes(), &getRes.Body); err != nil { + t.Fatal(err) + } + if getRes.Body.Unit != "sshd.service" || getRes.Body.ActiveState != "active" { + t.Errorf("get service output: %+v", getRes.Body) + } + + // 3. Test POST /api/services/{unit}/start + resp = api.Post("/api/services/sshd.service/start", struct{}{}) + if resp.Code != http.StatusOK { + t.Errorf("start service: got %d, want %d", resp.Code, http.StatusOK) + } + + // 4. Test GET /api/services/{unit}/logs (journal source) + oscmd.SetMock("journalctl", func(args []string) oscmd.MockCommand { + if strings.Contains(strings.Join(args, " "), "-f") { + // Streaming mock + lines := []string{ + `{"MESSAGE":"streaming line 1","PRIORITY":"6","__REALTIME_TIMESTAMP":"1718873704000000"}`, + } + return oscmd.MockCommand{Lines: lines, DelayMs: 1, ExitCode: 0} + } + // Regular snapshot mock + lines := []string{ + `{"MESSAGE":"journal line 1","PRIORITY":"6","__REALTIME_TIMESTAMP":"1718873704000000"}`, + } + return oscmd.MockCommand{Stdout: strings.Join(lines, "\n") + "\n", ExitCode: 0} + }) + + resp = api.Get("/api/services/sshd.service/logs") + if resp.Code != http.StatusOK { + t.Errorf("get journal logs: got %d, want %d", resp.Code, http.StatusOK) + } + var logsRes LogsOutput + if err := json.Unmarshal(resp.Body.Bytes(), &logsRes.Body); err != nil { + t.Fatal(err) + } + if len(logsRes.Body.Entries) != 1 || logsRes.Body.Entries[0].Message != "journal line 1" { + t.Errorf("journal logs output: %+v", logsRes.Body) + } + + // 5. Test GET /api/services/{unit}/logs (file source) + oscmd.SetMock("tail", func(args []string) oscmd.MockCommand { + if strings.Contains(strings.Join(args, " "), "-F") { + // Streaming mock + return oscmd.MockCommand{Lines: []string{"stream file line 1"}, DelayMs: 1, ExitCode: 0} + } + return oscmd.MockCommand{Stdout: "file log line 1\nfile log line 2\n", ExitCode: 0} + }) + + resp = api.Get("/api/services/nginx.service/logs?source=file&path=" + errLogPath) + if resp.Code != http.StatusOK { + t.Errorf("get file logs: got %d, want %d", resp.Code, http.StatusOK) + } + if err := json.Unmarshal(resp.Body.Bytes(), &logsRes.Body); err != nil { + t.Fatal(err) + } + if len(logsRes.Body.Entries) != 2 || logsRes.Body.Entries[0].Message != "file log line 1" { + t.Errorf("file logs output: %+v", logsRes.Body) + } + + // 6. Test GET /api/services/{unit}/logs/stream (journal stream) + resp = api.Get("/api/services/sshd.service/logs/stream") + if resp.Code != http.StatusOK { + t.Errorf("stream journal logs: got %d, want %d", resp.Code, http.StatusOK) + } + bodyStr := resp.Body.String() + if !strings.Contains(bodyStr, "streaming line 1") { + t.Errorf("stream journal logs missing message, got: %q", bodyStr) + } + + // 7. Test GET /api/services/{unit}/logs/stream (file stream) + resp = api.Get("/api/services/nginx.service/logs/stream?source=file&path=" + errLogPath) + if resp.Code != http.StatusOK { + t.Errorf("stream file logs: got %d, want %d", resp.Code, http.StatusOK) + } + bodyStr = resp.Body.String() + if !strings.Contains(bodyStr, "stream file line 1") { + t.Errorf("stream file logs missing message, got: %q", bodyStr) + } +} diff --git a/internal/modules/services/services_test.go b/internal/modules/services/services_test.go new file mode 100644 index 0000000..579942d --- /dev/null +++ b/internal/modules/services/services_test.go @@ -0,0 +1,21 @@ +package services + +import "testing" + +func TestValidateUnit(t *testing.T) { + valid := []string{"sshd.service", "getty@tty1.service", "foo.bar:baz-1.service", "a_b.timer"} + for _, u := range valid { + if err := validateUnit(u); err != nil { + t.Errorf("validateUnit(%q) = %v, want nil", u, err) + } + } + + // Empty, flag-injection, and anything with shell/path metacharacters must + // be rejected before reaching systemctl. + invalid := []string{"", "-rf", "--now", "a b", "foo;rm -rf /", "a/b", "naughty$()", "x|y"} + for _, u := range invalid { + if err := validateUnit(u); err == nil { + t.Errorf("validateUnit(%q) = nil, want error", u) + } + } +} diff --git a/internal/modules/storage/command.go b/internal/modules/storage/command.go new file mode 100644 index 0000000..c0fa9b0 --- /dev/null +++ b/internal/modules/storage/command.go @@ -0,0 +1,48 @@ +package storage + +import ( + "regexp" + "strings" + + "github.com/danielgtaylor/huma/v2" +) + +const tagStorage = "Storage" + +var ( + readErrors = []int{401, 403, 500} + writeErrors = []int{400, 401, 403, 409, 500} +) + +// Validation regexes. Every value is also passed after "--" on the command line. +// We forbid whitespace and shell metacharacters outright (mount/umount take no +// such values for our purposes, and fstab fields with spaces would need octal +// escaping we don't emit). +var ( + // deviceRe allows a /dev path or a UUID=/LABEL= specifier. + deviceRe = regexp.MustCompile(`^(/dev/[a-zA-Z0-9/._-]+|UUID=[a-zA-Z0-9-]+|LABEL=[a-zA-Z0-9._-]+|PARTUUID=[a-zA-Z0-9-]+)$`) + // mountpointRe is an absolute path without traversal or whitespace. + mountpointRe = regexp.MustCompile(`^/[a-zA-Z0-9/._-]*$`) + // fstypeRe matches a filesystem type token (ext4, xfs, btrfs, vfat, nfs, …). + fstypeRe = regexp.MustCompile(`^[a-z0-9]+$`) + // optionsRe matches a comma-separated mount option list (defaults, noatime, + // uid=1000, …). + optionsRe = regexp.MustCompile(`^[a-zA-Z0-9=,._:+-]+$`) +) + +func validateMountpoint(mp string) error { + if !mountpointRe.MatchString(mp) || hasDotDot(mp) { + return huma.Error400BadRequest("invalid mountpoint: " + mp) + } + return nil +} + +// hasDotDot rejects a ".." path segment so a mountpoint can't escape upward. +func hasDotDot(p string) bool { + for _, seg := range strings.Split(p, "/") { + if seg == ".." { + return true + } + } + return false +} diff --git a/internal/modules/storage/module.go b/internal/modules/storage/module.go new file mode 100644 index 0000000..f05f163 --- /dev/null +++ b/internal/modules/storage/module.go @@ -0,0 +1,30 @@ +package storage + +import ( + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +const ModuleID = "storage" + +type Module struct{} + +func New() *Module { return &Module{} } + +func (m *Module) ID() string { return ModuleID } +func (m *Module) Name() string { return "Storage" } + +// Permissions: read to list mounts and fstab; write to mount/unmount and edit +// fstab entries. +func (m *Module) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Read, rbac.Write} +} + +func (m *Module) Register(api huma.API) { + registerStorage(api) +} + +func op(permission string) map[string]any { + return map[string]any{"module": ModuleID, "permission": permission} +} diff --git a/internal/modules/storage/storage.go b/internal/modules/storage/storage.go new file mode 100644 index 0000000..12d239d --- /dev/null +++ b/internal/modules/storage/storage.go @@ -0,0 +1,311 @@ +package storage + +import ( + "context" + "fmt" + "os" + "strconv" + "strings" + + "nadir/internal/mounts" + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +// fstabFile is a var so tests can point it at a fixture. +var fstabFile = "/etc/fstab" + +// FstabEntry is one /etc/fstab line. Dump and Pass are the last two numeric +// fields (fs_freq and fs_passno). +type FstabEntry struct { + Device string `json:"device" example:"UUID=1234-5678"` + Mountpoint string `json:"mountpoint" example:"/mnt/data"` + FSType string `json:"fstype" example:"ext4"` + Options string `json:"options" example:"defaults"` + Dump int `json:"dump" example:"0"` + Pass int `json:"pass" example:"2" doc:"fsck order (0 = skip)"` +} + +type ListMountsOutput struct { + Body struct { + Mounts []mounts.Mount `json:"mounts"` + } +} + +type ListFstabOutput struct { + Body struct { + Entries []FstabEntry `json:"entries"` + } +} + +type MountInput struct { + Body struct { + Device string `json:"device" example:"/dev/sdb1" doc:"Block device or UUID=/LABEL= specifier"` + Mountpoint string `json:"mountpoint" example:"/mnt/data" doc:"Absolute mount path"` + FSType string `json:"fstype" example:"ext4"` + Options string `json:"options,omitempty" example:"defaults" doc:"Mount options (default: defaults)"` + Dump int `json:"dump,omitempty"` + Pass int `json:"pass,omitempty" doc:"fsck order (default 2 for real filesystems, 0 to skip)"` + } +} + +type UnmountInput struct { + Mountpoint string `query:"mountpoint" example:"/mnt/data" doc:"Mountpoint to unmount and remove from fstab"` +} + +func registerStorage(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "storage-list-mounts", + Method: "GET", + Path: "/api/storage/mounts", + Summary: "List active mounts", + Description: "Returns the kernel mount table (/proc/mounts).", + Tags: []string{tagStorage}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListMountsOutput, error) { + entries, err := mounts.Proc() + if err != nil { + return nil, huma.Error500InternalServerError("read mounts failed", err) + } + res := &ListMountsOutput{} + res.Body.Mounts = entries + return res, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "storage-list-fstab", + Method: "GET", + Path: "/api/storage/fstab", + Summary: "List /etc/fstab entries", + Description: "Returns the persistent mount definitions from /etc/fstab.", + Tags: []string{tagStorage}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListFstabOutput, error) { + data, err := os.ReadFile(fstabFile) + if err != nil { + return nil, huma.Error500InternalServerError("read fstab failed", err) + } + res := &ListFstabOutput{} + res.Body.Entries = parseFstab(string(data)) + return res, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "storage-add-mount", + Method: "POST", + Path: "/api/storage/mounts", + Summary: "Add and mount a filesystem", + Description: "Appends an /etc/fstab entry and mounts it. If the mount fails the " + + "fstab entry is rolled back, so a bad request leaves the system unchanged.", + Tags: []string{tagStorage}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *MountInput) (*oscmd.StatusOutput, error) { + e := FstabEntry{ + Device: in.Body.Device, + Mountpoint: in.Body.Mountpoint, + FSType: in.Body.FSType, + Options: in.Body.Options, + Dump: in.Body.Dump, + Pass: in.Body.Pass, + } + if e.Options == "" { + e.Options = "defaults" + } + if err := validateEntry(e); err != nil { + return nil, err + } + + existing, err := readFstab() + if err != nil { + return nil, huma.Error500InternalServerError("read fstab failed", err) + } + if findEntry(existing, e.Mountpoint) != nil { + return nil, huma.Error409Conflict("an fstab entry already exists for " + e.Mountpoint) + } + + if err := appendFstabLine(e); err != nil { + return nil, huma.Error500InternalServerError("write fstab failed", err) + } + // mount reads the fstab line we just wrote. On failure, roll the line back + // so a bad device/options doesn't linger in fstab and break the next boot. + if _, err := oscmd.RunContext(ctx, "mount", "--", e.Mountpoint); err != nil { + if _, werr := removeFstabLines(e.Mountpoint); werr != nil { + return nil, huma.Error500InternalServerError("mount failed and fstab rollback failed", werr) + } + return nil, huma.Error400BadRequest("mount failed: " + err.Error()) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "storage-remove-mount", + Method: "DELETE", + Path: "/api/storage/mounts", + Summary: "Unmount and remove a filesystem", + Description: "Unmounts the filesystem (if mounted) and removes its /etc/fstab entry. " + + "The mountpoint is passed as a query parameter since it contains slashes.", + Tags: []string{tagStorage}, + Metadata: op("write"), + Errors: []int{400, 401, 403, 404, 409, 500}, + }, func(ctx context.Context, in *UnmountInput) (*oscmd.StatusOutput, error) { + if err := validateMountpoint(in.Mountpoint); err != nil { + return nil, err + } + + entries, err := readFstab() + if err != nil { + return nil, huma.Error500InternalServerError("read fstab failed", err) + } + inFstab := findEntry(entries, in.Mountpoint) != nil + + mounted, err := isMounted(in.Mountpoint) + if err != nil { + return nil, huma.Error500InternalServerError("read mounts failed", err) + } + if !inFstab && !mounted { + return nil, huma.Error404NotFound("no mount or fstab entry for " + in.Mountpoint) + } + + if mounted { + if _, err := oscmd.RunContext(ctx, "umount", "--", in.Mountpoint); err != nil { + return nil, huma.Error409Conflict("unmount failed (in use?): " + err.Error()) + } + } + if inFstab { + if _, err := removeFstabLines(in.Mountpoint); err != nil { + return nil, huma.Error500InternalServerError("write fstab failed", err) + } + } + return oscmd.OK(), nil + }) +} + +// --- fstab helpers ----------------------------------------------------------- + +func readFstab() ([]FstabEntry, error) { + data, err := os.ReadFile(fstabFile) + if err != nil { + return nil, err + } + return parseFstab(string(data)), nil +} + +// parseFstab parses /etc/fstab, skipping comments and blank lines. +func parseFstab(data string) []FstabEntry { + entries := []FstabEntry{} + for line := range strings.SplitSeq(data, "\n") { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + f := strings.Fields(trimmed) + if len(f) < 4 { // device mountpoint fstype options [dump] [pass] + continue + } + e := FstabEntry{ + Device: mounts.Unescape(f[0]), + Mountpoint: mounts.Unescape(f[1]), + FSType: f[2], + Options: f[3], + } + if len(f) >= 5 { + e.Dump, _ = strconv.Atoi(f[4]) + } + if len(f) >= 6 { + e.Pass, _ = strconv.Atoi(f[5]) + } + entries = append(entries, e) + } + return entries +} + +// fstab edits are surgical (append one line / drop matching lines) rather than a +// whole-file rewrite, so an admin's existing entries, comments and formatting in +// this boot-critical file are preserved — same approach as /etc/hosts. + +func renderFstabLine(e FstabEntry) string { + return fmt.Sprintf("%s\t%s\t%s\t%s\t%d\t%d", e.Device, e.Mountpoint, e.FSType, e.Options, e.Dump, e.Pass) +} + +// appendFstabLine adds one entry, leaving every existing line untouched. +func appendFstabLine(e FstabEntry) error { + data, err := os.ReadFile(fstabFile) + if err != nil { + return err + } + content := string(data) + if content != "" && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += renderFstabLine(e) + "\n" + return os.WriteFile(fstabFile, []byte(content), 0644) +} + +// removeFstabLines drops every line mapping mountpoint, preserving comments and +// other entries. Reports whether anything was removed. +func removeFstabLines(mountpoint string) (bool, error) { + data, err := os.ReadFile(fstabFile) + if err != nil { + return false, err + } + lines := strings.Split(string(data), "\n") + kept := make([]string, 0, len(lines)) + removed := false + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed != "" && !strings.HasPrefix(trimmed, "#") { + if f := strings.Fields(trimmed); len(f) >= 2 && mounts.Unescape(f[1]) == mountpoint { + removed = true + continue + } + } + kept = append(kept, line) + } + if !removed { + return false, nil + } + return true, os.WriteFile(fstabFile, []byte(strings.Join(kept, "\n")), 0644) +} + +func findEntry(entries []FstabEntry, mountpoint string) *FstabEntry { + for i := range entries { + if entries[i].Mountpoint == mountpoint { + return &entries[i] + } + } + return nil +} + +// isMounted reports whether mountpoint currently appears in /proc/mounts. +func isMounted(mountpoint string) (bool, error) { + active, err := mounts.Proc() + if err != nil { + return false, err + } + for _, e := range active { + if e.Mountpoint == mountpoint { + return true, nil + } + } + return false, nil +} + +func validateEntry(e FstabEntry) error { + if !deviceRe.MatchString(e.Device) { + return huma.Error400BadRequest("invalid device: " + e.Device) + } + if err := validateMountpoint(e.Mountpoint); err != nil { + return err + } + if !fstypeRe.MatchString(e.FSType) { + return huma.Error400BadRequest("invalid fstype: " + e.FSType) + } + if !optionsRe.MatchString(e.Options) { + return huma.Error400BadRequest("invalid options: " + e.Options) + } + return nil +} diff --git a/internal/modules/storage/storage_handler_test.go b/internal/modules/storage/storage_handler_test.go new file mode 100644 index 0000000..eb48a1f --- /dev/null +++ b/internal/modules/storage/storage_handler_test.go @@ -0,0 +1,83 @@ +package storage + +import ( + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestStorageHandlers(t *testing.T) { + path := filepath.Join(t.TempDir(), "fstab") + if err := os.WriteFile(path, []byte("# seed\nUUID=root / ext4 defaults 0 1\n"), 0644); err != nil { + t.Fatal(err) + } + old := fstabFile + fstabFile = path + defer func() { fstabFile = old }() + + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + New().Register(api) + + mountOK := true + oscmd.SetMock("mount", func([]string) oscmd.MockCommand { + if mountOK { + return oscmd.MockCommand{ExitCode: 0} + } + return oscmd.MockCommand{Stderr: "mount: wrong fs type", ExitCode: 32} + }) + oscmd.SetMock("umount", func([]string) oscmd.MockCommand { return oscmd.MockCommand{ExitCode: 0} }) + defer oscmd.ClearMocks() + + body := map[string]any{"device": "/dev/sdb1", "mountpoint": "/mnt/data", "fstype": "ext4"} + + // Add + mount succeeds. + if resp := api.Post("/api/storage/mounts", body); resp.Code != http.StatusOK { + t.Fatalf("add mount: got %d, body=%s", resp.Code, resp.Body.String()) + } + if findEntry(mustReadFstab(t), "/mnt/data") == nil { + t.Fatal("entry not written to fstab") + } + + // Duplicate mountpoint → 409. + if resp := api.Post("/api/storage/mounts", body); resp.Code != http.StatusConflict { + t.Errorf("duplicate: got %d, want 409", resp.Code) + } + + // mount fails → 400 and the fstab line is rolled back. + mountOK = false + bad := map[string]any{"device": "/dev/sdc1", "mountpoint": "/mnt/bad", "fstype": "ext4"} + if resp := api.Post("/api/storage/mounts", bad); resp.Code != http.StatusBadRequest { + t.Errorf("failed mount: got %d, want 400", resp.Code) + } + if findEntry(mustReadFstab(t), "/mnt/bad") != nil { + t.Error("fstab entry not rolled back after mount failure") + } + + // Invalid device is rejected before touching the system. + if resp := api.Post("/api/storage/mounts", map[string]any{"device": "/dev/x; rm", "mountpoint": "/mnt/x", "fstype": "ext4"}); resp.Code != http.StatusBadRequest { + t.Errorf("invalid device: got %d, want 400", resp.Code) + } + + // Delete the good entry (not actually mounted, so no umount needed). + if resp := api.Delete("/api/storage/mounts?mountpoint=/mnt/data"); resp.Code != http.StatusOK { + t.Errorf("delete: got %d, body=%s", resp.Code, resp.Body.String()) + } + data, _ := os.ReadFile(path) + if strings.Contains(string(data), "/mnt/data") || !strings.Contains(string(data), "# seed") { + t.Errorf("delete result wrong:\n%s", data) + } + + // Deleting a nonexistent mountpoint → 404. + if resp := api.Delete("/api/storage/mounts?mountpoint=/mnt/none"); resp.Code != http.StatusNotFound { + t.Errorf("delete missing: got %d, want 404", resp.Code) + } +} diff --git a/internal/modules/storage/storage_test.go b/internal/modules/storage/storage_test.go new file mode 100644 index 0000000..2c9f02e --- /dev/null +++ b/internal/modules/storage/storage_test.go @@ -0,0 +1,110 @@ +package storage + +import ( + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "nadir/internal/oscmd" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestParseFstab(t *testing.T) { + data := `# /etc/fstab +UUID=1111-2222 / ext4 defaults 0 1 + +/dev/sdb1 /mnt/data ext4 rw,noatime 0 2 +LABEL=swap none swap sw +` + got := parseFstab(data) + want := []FstabEntry{ + {Device: "UUID=1111-2222", Mountpoint: "/", FSType: "ext4", Options: "defaults", Dump: 0, Pass: 1}, + {Device: "/dev/sdb1", Mountpoint: "/mnt/data", FSType: "ext4", Options: "rw,noatime", Dump: 0, Pass: 2}, + {Device: "LABEL=swap", Mountpoint: "none", FSType: "swap", Options: "sw", Dump: 0, Pass: 0}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("parseFstab:\n got %+v\nwant %+v", got, want) + } +} + +func TestFstabSurgicalEdits(t *testing.T) { + path := filepath.Join(t.TempDir(), "fstab") + old := fstabFile + fstabFile = path + defer func() { fstabFile = old }() + + // Seed with a comment + an existing entry that must survive edits. + seed := "# keep me\nUUID=root / ext4 defaults 0 1\n" + if err := os.WriteFile(path, []byte(seed), 0644); err != nil { + t.Fatal(err) + } + + // Append a new entry. + if err := appendFstabLine(FstabEntry{Device: "/dev/sdb1", Mountpoint: "/mnt/data", FSType: "xfs", Options: "noatime", Pass: 2}); err != nil { + t.Fatal(err) + } + entries, err := readFstab() + if err != nil { + t.Fatal(err) + } + if len(entries) != 2 || findEntry(entries, "/mnt/data") == nil { + t.Fatalf("append failed: %+v", entries) + } + + // Remove it again; the comment and original entry must remain. + removed, err := removeFstabLines("/mnt/data") + if err != nil || !removed { + t.Fatalf("remove failed: removed=%v err=%v", removed, err) + } + data, _ := os.ReadFile(path) + if !strings.Contains(string(data), "# keep me") || !strings.Contains(string(data), "UUID=root") { + t.Errorf("surgical edit clobbered other lines:\n%s", data) + } + if findEntry(mustReadFstab(t), "/mnt/data") != nil { + t.Error("entry still present after remove") + } + + // Removing a missing mountpoint reports not-removed. + if removed, _ := removeFstabLines("/nope"); removed { + t.Error("expected removed=false for missing mountpoint") + } +} + +func mustReadFstab(t *testing.T) []FstabEntry { + t.Helper() + e, err := readFstab() + if err != nil { + t.Fatal(err) + } + return e +} + +func TestValidateEntry(t *testing.T) { + ok := FstabEntry{Device: "/dev/sdb1", Mountpoint: "/mnt/data", FSType: "ext4", Options: "defaults"} + if err := validateEntry(ok); err != nil { + t.Errorf("valid entry rejected: %v", err) + } + if err := validateEntry(FstabEntry{Device: "UUID=ab-cd", Mountpoint: "/mnt/x", FSType: "xfs", Options: "rw,noatime"}); err != nil { + t.Errorf("UUID entry rejected: %v", err) + } + bad := []FstabEntry{ + {Device: "/dev/sdb1; rm -rf /", Mountpoint: "/mnt/x", FSType: "ext4", Options: "defaults"}, // shell metachars + {Device: "/dev/sdb1", Mountpoint: "../etc", FSType: "ext4", Options: "defaults"}, // not absolute + {Device: "/dev/sdb1", Mountpoint: "/mnt/../../etc", FSType: "ext4", Options: "defaults"}, // traversal + {Device: "/dev/sdb1", Mountpoint: "/mnt/x", FSType: "ext4!", Options: "defaults"}, // bad fstype + {Device: "/dev/sdb1", Mountpoint: "/mnt/x", FSType: "ext4", Options: "defaults; reboot"}, // bad options + } + for i, e := range bad { + if err := validateEntry(e); err == nil { + t.Errorf("bad entry %d accepted: %+v", i, e) + } + } +} diff --git a/internal/modules/system/command.go b/internal/modules/system/command.go new file mode 100644 index 0000000..5431435 --- /dev/null +++ b/internal/modules/system/command.go @@ -0,0 +1,10 @@ +package system + +// tagSystem groups every system-module operation (info, hostname, time/date, +// locale, power) under one OpenAPI tag, so tags map 1:1 to modules. +const tagSystem = "System" + +var ( + readErrors = []int{401, 403, 500} + writeErrors = []int{400, 401, 403, 500} +) diff --git a/internal/modules/system/hostname.go b/internal/modules/system/hostname.go new file mode 100644 index 0000000..0177bfd --- /dev/null +++ b/internal/modules/system/hostname.go @@ -0,0 +1,57 @@ +package system + +import ( + "context" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +type HostnameBody struct { + Hostname string `json:"hostname" example:"server01" doc:"System hostname"` +} + +type GetHostnameOutput struct{ Body HostnameBody } +type SetHostnameInput struct{ Body HostnameBody } + +func registerHostname(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "system-get-hostname", + Method: "GET", + Path: "/api/system/hostname", + Summary: "Get system hostname", + Description: "Returns the current hostname as reported by hostnamectl.", + Tags: []string{tagSystem}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*GetHostnameOutput, error) { + name, err := oscmd.Run("hostnamectl", "hostname") + if err != nil { + return nil, huma.Error500InternalServerError("hostnamectl failed", err) + } + return &GetHostnameOutput{Body: HostnameBody{Hostname: name}}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-set-hostname", + Method: "POST", + Path: "/api/system/hostname", + Summary: "Set system hostname", + Description: "Sets the static hostname via hostnamectl, which owns " + + "/etc/hostname and manages the static/pretty/transient names.", + Tags: []string{tagSystem}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *SetHostnameInput) (*oscmd.StatusOutput, error) { + name := strings.TrimSpace(in.Body.Hostname) + if name == "" { + return nil, huma.Error400BadRequest("empty hostname") + } + if _, err := oscmd.Run("hostnamectl", "set-hostname", name); err != nil { + return nil, huma.Error500InternalServerError("hostnamectl failed", err) + } + return oscmd.OK(), nil + }) +} diff --git a/internal/modules/system/info.go b/internal/modules/system/info.go new file mode 100644 index 0000000..0e4a5b2 --- /dev/null +++ b/internal/modules/system/info.go @@ -0,0 +1,544 @@ +package system + +import ( + "context" + "math" + "net" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "nadir/internal/mounts" + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +// SystemInfoBody is the dashboard overview: OS identity plus live CPU, memory, +// disk, load, network, and temperature readings. Every section is best-effort — +// a source that's unavailable (e.g. no thermal zones in a VM) yields a zero +// value or empty list rather than failing the whole call. +type SystemInfoBody struct { + OS OSInfo `json:"os" doc:"OS and kernel identity"` + CPU CPUInfo `json:"cpu" doc:"Processor model and core count"` + Memory MemoryInfo `json:"memory" doc:"RAM and swap usage in bytes"` + Load LoadInfo `json:"load" doc:"Load averages (1/5/15 min)"` + UptimeSec int64 `json:"uptime_seconds" example:"12490" doc:"Seconds since boot"` + BootTime string `json:"boot_time" example:"2026-06-19T12:08:00Z" doc:"Boot time (RFC3339, UTC)"` + Disks []DiskInfo `json:"disks" doc:"Mounted block-device filesystems"` + NetworkInterfaces []NetInterface `json:"network_interfaces" doc:"Network interfaces and their addresses"` + Temperatures []Temperature `json:"temperatures" doc:"Thermal sensor readings in Celsius"` +} + +type OSInfo struct { + PrettyName string `json:"pretty_name" example:"Fedora Linux 44 (Workstation Edition)" doc:"Distro name from /etc/os-release PRETTY_NAME"` + Kernel string `json:"kernel" example:"7.0.12-201.fc44.x86_64" doc:"Running kernel release (uname -r)"` + Architecture string `json:"architecture" example:"x86_64" doc:"Machine hardware architecture (uname -m)"` + Hostname string `json:"hostname" example:"server01" doc:"System hostname"` +} + +type CPUInfo struct { + Model string `json:"model" example:"AMD Ryzen 7 7840U" doc:"CPU model name"` + LogicalCPUs int `json:"logical_cpus" example:"16" doc:"Number of logical CPUs (cores × threads)"` + MinMHz int `json:"min_mhz" example:"400" doc:"Lowest frequency the scaling governor can select"` + MaxMHz int `json:"max_mhz" example:"5137" doc:"Highest frequency (boost ceiling)"` + CurrentMHz int `json:"current_mhz" example:"3157" doc:"Peak current clock across all cores (instantaneous snapshot; 0 if cpufreq unavailable)"` +} + +type MemoryInfo struct { + TotalBytes uint64 `json:"total_bytes" example:"16384000000"` + AvailableBytes uint64 `json:"available_bytes" example:"8192000000" doc:"Memory available for new allocations without swapping"` + UsedBytes uint64 `json:"used_bytes" example:"8192000000" doc:"total - available"` + SwapTotalBytes uint64 `json:"swap_total_bytes" example:"8589934592"` + SwapFreeBytes uint64 `json:"swap_free_bytes" example:"8589934592"` +} + +type LoadInfo struct { + Load1 float64 `json:"load1" example:"0.42"` + Load5 float64 `json:"load5" example:"0.55"` + Load15 float64 `json:"load15" example:"0.61"` + CPUUsage []CoreUsage `json:"cpu_usage" doc:"Per-core CPU usage percentage (sampled over ~1 s); empty until the first sample completes"` +} + +// CoreUsage holds the usage percentage for a single logical core, computed as +// the delta of non-idle ticks over total ticks between two /proc/stat reads. +type CoreUsage struct { + Core int `json:"core" example:"0" doc:"Logical core index"` + UsagePct float64 `json:"usage_pct" example:"23.4" doc:"Usage percentage (0–100)"` +} + +type DiskInfo struct { + Mountpoint string `json:"mountpoint" example:"/"` + Filesystem string `json:"filesystem" example:"/dev/nvme0n1p2" doc:"Backing device"` + FSType string `json:"fstype" example:"btrfs"` + TotalBytes uint64 `json:"total_bytes" example:"512000000000"` + FreeBytes uint64 `json:"free_bytes" example:"256000000000" doc:"Space available to unprivileged users"` + UsedBytes uint64 `json:"used_bytes" example:"256000000000"` +} + +type NetInterface struct { + Name string `json:"name" example:"eth0"` + MAC string `json:"mac" example:"aa:bb:cc:dd:ee:ff"` + Up bool `json:"up" doc:"Interface is administratively up"` + Addresses []string `json:"addresses" doc:"Assigned addresses in CIDR notation"` +} + +type Temperature struct { + Chip string `json:"chip" example:"k10temp" doc:"hwmon chip name; identifies the source (k10temp/coretemp=CPU, amdgpu/nvidia=GPU, nvme=disk)"` + Label string `json:"label" example:"Tctl" doc:"Per-sensor label, or the chip name when the sensor is unlabelled"` + Celsius float64 `json:"celsius" example:"47.5"` +} + +type GetInfoOutput struct{ Body SystemInfoBody } + +func registerInfo(api huma.API) { + startCPUSampler() + huma.Register(api, huma.Operation{ + OperationID: "system-get-info", + Method: "GET", + Path: "/api/system/info", + Summary: "Get system information", + Description: "Returns an overview for a dashboard: OS/kernel identity, CPU, " + + "memory and swap, mounted disks, load averages, uptime, network " + + "interfaces, and temperatures. All values come from cheap local reads " + + "(/proc, /sys, syscalls) with no D-Bus dependency; each section is " + + "best-effort.", + Tags: []string{tagSystem}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*GetInfoOutput, error) { + uptime, boot := uptimeAndBoot() + return &GetInfoOutput{Body: SystemInfoBody{ + OS: osInfo(), + CPU: cpuInfo(), + Memory: memInfo(), + Load: loadInfo(), + UptimeSec: uptime, + BootTime: boot.Format(time.RFC3339), + Disks: diskInfo(), + NetworkInterfaces: netInfo(), + Temperatures: tempInfo(), + }}, nil + }) +} + +func osInfo() OSInfo { + host, _ := os.Hostname() + return OSInfo{ + PrettyName: osReleasePretty(), + Kernel: firstLine(oscmd.Run("uname", "-r")), + Architecture: firstLine(oscmd.Run("uname", "-m")), + Hostname: host, + } +} + +// firstLine discards a command error and returns its (already trimmed) output, +// used where a missing value is acceptable. +func firstLine(out string, _ error) string { return out } + +func osReleasePretty() string { + data, err := os.ReadFile("/etc/os-release") + if err != nil { + return "" + } + for line := range strings.SplitSeq(string(data), "\n") { + if v, ok := strings.CutPrefix(line, "PRETTY_NAME="); ok { + return strings.Trim(v, `"`) + } + } + return "" +} + +func cpuInfo() CPUInfo { + data, _ := os.ReadFile("/proc/cpuinfo") + c := CPUInfo{Model: cpuModel(string(data)), LogicalCPUs: runtime.NumCPU()} + c.MinMHz, c.MaxMHz, c.CurrentMHz = cpuFreqMHz("/sys/devices/system/cpu") + // ponytail: cpufreq sysfs is absent on many VMs and stock Ubuntu server + // kernels; fall back to /proc/cpuinfo "cpu MHz" so CurrentMHz isn't 0. + if c.CurrentMHz == 0 { + c.CurrentMHz = cpuinfoMaxMHz(string(data)) + } + return c +} + +// cpuinfoMaxMHz returns the highest "cpu MHz" value across all cores in +// /proc/cpuinfo, rounded to an int. Returns 0 when no such line exists. +func cpuinfoMaxMHz(cpuinfo string) int { + var max float64 + for line := range strings.SplitSeq(cpuinfo, "\n") { + k, v, ok := strings.Cut(line, ":") + if !ok || strings.TrimSpace(k) != "cpu MHz" { + continue + } + if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil && f > max { + max = f + } + } + return int(math.Round(max)) +} + +// cpuFreqMHz reads cpufreq sysfs: min/max are stable hardware limits (from +// cpu0); current is the highest scaling_cur_freq across all cores — the "is it +// boosting" figure. Values are kHz in sysfs. Returns zeros when cpufreq is +// absent (e.g. some VMs). +func cpuFreqMHz(root string) (min, max, cur int) { + min = readKHzAsMHz(filepath.Join(root, "cpu0/cpufreq/cpuinfo_min_freq")) + max = readKHzAsMHz(filepath.Join(root, "cpu0/cpufreq/cpuinfo_max_freq")) + cores, _ := filepath.Glob(filepath.Join(root, "cpu[0-9]*/cpufreq/scaling_cur_freq")) + for _, f := range cores { + if v := readKHzAsMHz(f); v > cur { + cur = v + } + } + return min, max, cur +} + +func readKHzAsMHz(path string) int { + khz, err := strconv.Atoi(readTrim(path)) + if err != nil { + return 0 + } + return khz / 1000 +} + +// cpuModel extracts the processor model from /proc/cpuinfo. x86 uses "model +// name"; many ARM boards use "Model" instead, so fall back to it. +func cpuModel(cpuinfo string) string { + var fallback string + for line := range strings.SplitSeq(cpuinfo, "\n") { + k, v, ok := strings.Cut(line, ":") + if !ok { + continue + } + switch strings.TrimSpace(k) { + case "model name": + return strings.TrimSpace(v) + case "Model": + fallback = strings.TrimSpace(v) + } + } + return fallback +} + +func memInfo() MemoryInfo { + data, _ := os.ReadFile("/proc/meminfo") + return parseMeminfo(data) +} + +// parseMeminfo reads the kB values in /proc/meminfo and converts them to bytes. +func parseMeminfo(data []byte) MemoryInfo { + kv := map[string]uint64{} + for line := range strings.SplitSeq(string(data), "\n") { + k, v, ok := strings.Cut(line, ":") + if !ok { + continue + } + fields := strings.Fields(v) // e.g. "16384000 kB" + if len(fields) == 0 { + continue + } + if n, err := strconv.ParseUint(fields[0], 10, 64); err == nil { + kv[k] = n * 1024 // values are in kB + } + } + return MemoryInfo{ + TotalBytes: kv["MemTotal"], + AvailableBytes: kv["MemAvailable"], + UsedBytes: kv["MemTotal"] - kv["MemAvailable"], + SwapTotalBytes: kv["SwapTotal"], + SwapFreeBytes: kv["SwapFree"], + } +} + +func loadInfo() LoadInfo { + data, _ := os.ReadFile("/proc/loadavg") + l := parseLoadavg(string(data)) + l.CPUUsage = cachedCPUUsage() + return l +} + +// parseLoadavg reads the three load averages from /proc/loadavg. +func parseLoadavg(loadavg string) LoadInfo { + f := strings.Fields(loadavg) + if len(f) < 3 { + return LoadInfo{} + } + at := func(i int) float64 { v, _ := strconv.ParseFloat(f[i], 64); return v } + return LoadInfo{Load1: at(0), Load5: at(1), Load15: at(2)} +} + +// --------------------------------------------------------------------------- +// Per-core CPU usage sampler +// --------------------------------------------------------------------------- +// +// /proc/stat exposes cumulative jiffies per core: +// +// cpuN user nice system idle iowait irq softirq steal guest guest_nice +// +// We sample every second, compute the delta, and derive: +// +// usage% = (totalΔ − idleΔ) / totalΔ × 100 +// +// The result is cached behind a RWMutex so the HTTP handler never blocks. + +var ( + usageMu sync.RWMutex + usageCache []CoreUsage +) + +func cachedCPUUsage() []CoreUsage { + usageMu.RLock() + defer usageMu.RUnlock() + // Return a copy so callers can't mutate the cache. + if usageCache == nil { + return nil + } + out := make([]CoreUsage, len(usageCache)) + copy(out, usageCache) + return out +} + +// startCPUSampler launches a goroutine that samples /proc/stat once per second +// for the lifetime of the process. Safe to call multiple times (only the first +// call starts the goroutine). +var samplerOnce sync.Once + +func startCPUSampler() { + samplerOnce.Do(func() { + go cpuSamplerLoop("/proc/stat", 1*time.Second) + }) +} + +func cpuSamplerLoop(statPath string, interval time.Duration) { + prev := readProcStat(statPath) + for { + time.Sleep(interval) + cur := readProcStat(statPath) + usage := computeUsage(prev, cur) + usageMu.Lock() + usageCache = usage + usageMu.Unlock() + prev = cur + } +} + +// cpuCoreTicks holds the cumulative jiffies for one "cpuN" line. +type cpuCoreTicks struct { + core int + total uint64 + idle uint64 +} + +// readProcStat reads /proc/stat and returns per-core tick totals. The +// aggregate "cpu" line (no digit suffix) is skipped. +func readProcStat(path string) []cpuCoreTicks { + data, _ := os.ReadFile(path) + var cores []cpuCoreTicks + for line := range strings.SplitSeq(string(data), "\n") { + if !strings.HasPrefix(line, "cpu") { + continue + } + fields := strings.Fields(line) + if len(fields) < 5 { + continue + } + // Skip the aggregate "cpu" line; we only want "cpu0", "cpu1", … + name := fields[0] + if name == "cpu" { + continue + } + coreIdx, err := strconv.Atoi(strings.TrimPrefix(name, "cpu")) + if err != nil { + continue + } + // Fields: user(1) nice(2) system(3) idle(4) iowait(5) irq(6) softirq(7) steal(8) … + var total, idle uint64 + for _, f := range fields[1:] { + v, _ := strconv.ParseUint(f, 10, 64) + total += v + } + // idle = idle + iowait (indices 4 and 5 in the original line). + if len(fields) > 5 { + v4, _ := strconv.ParseUint(fields[4], 10, 64) + v5, _ := strconv.ParseUint(fields[5], 10, 64) + idle = v4 + v5 + } else { + v4, _ := strconv.ParseUint(fields[4], 10, 64) + idle = v4 + } + cores = append(cores, cpuCoreTicks{core: coreIdx, total: total, idle: idle}) + } + return cores +} + +func computeUsage(prev, cur []cpuCoreTicks) []CoreUsage { + prevMap := make(map[int]cpuCoreTicks, len(prev)) + for _, c := range prev { + prevMap[c.core] = c + } + usage := make([]CoreUsage, 0, len(cur)) + for _, c := range cur { + p, ok := prevMap[c.core] + if !ok { + continue + } + dTotal := c.total - p.total + dIdle := c.idle - p.idle + var pct float64 + if dTotal > 0 { + pct = float64(dTotal-dIdle) / float64(dTotal) * 100 + // Round to one decimal. + pct = math.Round(pct*10) / 10 + } + usage = append(usage, CoreUsage{Core: c.core, UsagePct: pct}) + } + return usage +} + +// uptimeAndBoot reads /proc/uptime (seconds since boot) and derives boot time. +// On any read error it returns zero values rather than failing the request. +func uptimeAndBoot() (int64, time.Time) { + data, err := os.ReadFile("/proc/uptime") + if err != nil { + return 0, time.Time{} + } + fields := strings.Fields(string(data)) + if len(fields) == 0 { + return 0, time.Time{} + } + secs, err := strconv.ParseFloat(fields[0], 64) + if err != nil { + return 0, time.Time{} + } + boot := time.Now().Add(-time.Duration(secs * float64(time.Second))).UTC() + return int64(secs), boot +} + +func diskInfo() []DiskInfo { + entries, err := mounts.Proc() + if err != nil { + return nil + } + disks := []DiskInfo{} + seen := map[string]bool{} + for _, e := range entries { + // Only real block devices; skip pseudo filesystems and snap's squashfs + // loop mounts that would otherwise clutter the list. + if !strings.HasPrefix(e.Device, "/dev/") || e.FSType == "squashfs" || seen[e.Mountpoint] { + continue + } + var st syscall.Statfs_t + if syscall.Statfs(e.Mountpoint, &st) != nil || st.Blocks == 0 { + continue + } + seen[e.Mountpoint] = true + bs := uint64(st.Bsize) + disks = append(disks, DiskInfo{ + Mountpoint: e.Mountpoint, + Filesystem: e.Device, + FSType: e.FSType, + TotalBytes: st.Blocks * bs, + FreeBytes: st.Bavail * bs, + UsedBytes: (st.Blocks - st.Bfree) * bs, + }) + } + return disks +} + +func netInfo() []NetInterface { + ifaces, err := net.Interfaces() + if err != nil { + return nil + } + out := []NetInterface{} + for _, ifi := range ifaces { + addrs, _ := ifi.Addrs() + strs := []string{} + for _, a := range addrs { + strs = append(strs, a.String()) + } + out = append(out, NetInterface{ + Name: ifi.Name, + MAC: ifi.HardwareAddr.String(), + Up: ifi.Flags&net.FlagUp != 0, + Addresses: strs, + }) + } + return out +} + +func tempInfo() []Temperature { + if t := readHwmonTemps("/sys/class/hwmon"); len(t) > 0 { + return t + } + // ponytail: stock Ubuntu server has no coretemp/k10temp loaded, so hwmon + // is empty; thermal_zone exposes ACPI sensors (coarser, no chip name). + return readThermalZones("/sys/class/thermal") +} + +// readThermalZones reads /sys/class/thermal/thermal_zone*/temp as a fallback +// for hosts without hwmon chip drivers. "type" names the zone (e.g. acpitz, +// x86_pkg_temp); used as both chip and label. +func readThermalZones(root string) []Temperature { + zones, _ := filepath.Glob(filepath.Join(root, "thermal_zone*")) + temps := []Temperature{} + for _, dir := range zones { + milli, err := strconv.Atoi(readTrim(filepath.Join(dir, "temp"))) + if err != nil { + continue + } + c := float64(milli) / 1000 + if c <= 0 || c >= 150 { + continue + } + name := readTrim(filepath.Join(dir, "type")) + if name == "" { + name = filepath.Base(dir) + } + temps = append(temps, Temperature{Chip: name, Label: name, Celsius: c}) + } + return temps +} + +// readHwmonTemps walks /sys/class/hwmon, which (unlike /sys/class/thermal, that +// only exposes generic ACPI zones like "acpitz") names each chip — so callers +// can split CPU (k10temp/coretemp) from GPU (amdgpu/nvidia) from disk (nvme). +// Each tempN_input may carry a tempN_label; when absent we fall back to the +// chip name. Best-effort: unreadable, empty, or implausible sensors are skipped. +func readHwmonTemps(root string) []Temperature { + chips, _ := filepath.Glob(filepath.Join(root, "hwmon*")) + temps := []Temperature{} + for _, dir := range chips { + chip := readTrim(filepath.Join(dir, "name")) + inputs, _ := filepath.Glob(filepath.Join(dir, "temp*_input")) + for _, in := range inputs { + milli, err := strconv.Atoi(readTrim(in)) + if err != nil { + continue + } + c := float64(milli) / 1000 + // Disabled/placeholder sensors report absurd values (e.g. -0.15 or + // 179.8 °C). Drop anything outside a plausible band. + if c <= 0 || c >= 150 { + continue + } + label := readTrim(strings.TrimSuffix(in, "_input") + "_label") + if label == "" { + label = chip + } + temps = append(temps, Temperature{Chip: chip, Label: label, Celsius: c}) + } + } + return temps +} + +// readTrim reads a sysfs file and trims it; a missing file yields "". +func readTrim(path string) string { + b, _ := os.ReadFile(path) + return strings.TrimSpace(string(b)) +} diff --git a/internal/modules/system/info_test.go b/internal/modules/system/info_test.go new file mode 100644 index 0000000..2fae7df --- /dev/null +++ b/internal/modules/system/info_test.go @@ -0,0 +1,177 @@ +package system + +import ( + "os" + "path/filepath" + "testing" +) + +func TestReadHwmonTemps(t *testing.T) { + root := t.TempDir() + // k10temp: CPU, labelled Tctl. + write(t, root, "hwmon0", "name", "k10temp") + write(t, root, "hwmon0", "temp1_input", "62125") + write(t, root, "hwmon0", "temp1_label", "Tctl") + // amdgpu: GPU, no label -> falls back to chip name. Also a sensor reading + // an implausible value that must be dropped. + write(t, root, "hwmon1", "name", "amdgpu") + write(t, root, "hwmon1", "temp1_input", "56000") + write(t, root, "hwmon1", "temp2_input", "-150") // disabled placeholder + // empty input -> skipped without error. + write(t, root, "hwmon2", "name", "cros_ec") + write(t, root, "hwmon2", "temp1_input", "") + + got := readHwmonTemps(root) + if len(got) != 2 { + t.Fatalf("want 2 temps, got %d: %+v", len(got), got) + } + if got[0].Chip != "k10temp" || got[0].Label != "Tctl" || got[0].Celsius != 62.125 { + t.Errorf("cpu sensor wrong: %+v", got[0]) + } + if got[1].Chip != "amdgpu" || got[1].Label != "amdgpu" || got[1].Celsius != 56 { + t.Errorf("gpu sensor wrong (label should fall back to chip): %+v", got[1]) + } +} + +func write(t *testing.T, root, chip, file, val string) { + t.Helper() + dir := filepath.Join(root, chip) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, file), []byte(val), 0o644); err != nil { + t.Fatal(err) + } +} + +func TestParseMeminfo(t *testing.T) { + data := []byte(`MemTotal: 16000 kB +MemFree: 2000 kB +MemAvailable: 8000 kB +SwapTotal: 4000 kB +SwapFree: 3000 kB +`) + m := parseMeminfo(data) + const kb = 1024 + if m.TotalBytes != 16000*kb { + t.Errorf("TotalBytes = %d", m.TotalBytes) + } + if m.AvailableBytes != 8000*kb { + t.Errorf("AvailableBytes = %d", m.AvailableBytes) + } + if m.UsedBytes != (16000-8000)*kb { + t.Errorf("UsedBytes = %d, want %d", m.UsedBytes, (16000-8000)*kb) + } + if m.SwapTotalBytes != 4000*kb || m.SwapFreeBytes != 3000*kb { + t.Errorf("swap parsed wrong: %+v", m) + } +} + +func TestParseLoadavg(t *testing.T) { + l := parseLoadavg("0.42 0.55 0.61 1/234 5678") + if l.Load1 != 0.42 || l.Load5 != 0.55 || l.Load15 != 0.61 { + t.Errorf("got %+v", l) + } + bad := parseLoadavg("garbage") + if bad.Load1 != 0 || bad.Load5 != 0 || bad.Load15 != 0 { + t.Error("short input should yield zero LoadInfo") + } +} + +func TestCpuFreqMHz(t *testing.T) { + root := t.TempDir() + // Hardware limits live under cpu0 only. + write(t, root, "cpu0/cpufreq", "cpuinfo_min_freq", "400000") // 400 MHz + write(t, root, "cpu0/cpufreq", "cpuinfo_max_freq", "5137000") // 5137 MHz + // Per-core current frequencies (kHz). Core 2 has the peak. + write(t, root, "cpu0/cpufreq", "scaling_cur_freq", "2500000") + write(t, root, "cpu1/cpufreq", "scaling_cur_freq", "1100000") + write(t, root, "cpu2/cpufreq", "scaling_cur_freq", "3200000") + write(t, root, "cpu3/cpufreq", "scaling_cur_freq", "2800000") + + min, max, cur := cpuFreqMHz(root) + if min != 400 { + t.Errorf("min = %d, want 400", min) + } + if max != 5137 { + t.Errorf("max = %d, want 5137", max) + } + if cur != 3200 { + t.Errorf("cur = %d, want 3200 (peak across cores)", cur) + } + + // Absent cpufreq → all zeros. + emptyRoot := t.TempDir() + min, max, cur = cpuFreqMHz(emptyRoot) + if min != 0 || max != 0 || cur != 0 { + t.Errorf("absent cpufreq: got min=%d max=%d cur=%d, want all 0", min, max, cur) + } +} + +func TestCPUModel(t *testing.T) { + x86 := "processor\t: 0\nmodel name\t: AMD Ryzen 7 7840U\ncpu MHz\t: 3000\n" + if got := cpuModel(x86); got != "AMD Ryzen 7 7840U" { + t.Errorf("x86: got %q", got) + } + arm := "processor\t: 0\nModel\t: Raspberry Pi 5 Model B\n" + if got := cpuModel(arm); got != "Raspberry Pi 5 Model B" { + t.Errorf("arm fallback: got %q", got) + } + if got := cpuModel("processor\t: 0\n"); got != "" { + t.Errorf("no model: got %q", got) + } +} + +func TestReadProcStat(t *testing.T) { + content := `cpu 100 20 30 400 10 5 3 2 0 0 +cpu0 50 10 15 200 5 3 1 1 0 0 +cpu1 50 10 15 200 5 2 2 1 0 0 +intr 12345 +` + f := filepath.Join(t.TempDir(), "stat") + if err := os.WriteFile(f, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + cores := readProcStat(f) + if len(cores) != 2 { + t.Fatalf("want 2 cores, got %d", len(cores)) + } + // cpu0: user=50 nice=10 sys=15 idle=200 iowait=5 irq=3 softirq=1 steal=1 guest=0 guest_nice=0 + // total = 50+10+15+200+5+3+1+1+0+0 = 285 + // idle = 200 + 5 = 205 + if cores[0].core != 0 { + t.Errorf("core index = %d, want 0", cores[0].core) + } + if cores[0].total != 285 { + t.Errorf("core0 total = %d, want 285", cores[0].total) + } + if cores[0].idle != 205 { + t.Errorf("core0 idle = %d, want 205", cores[0].idle) + } +} + +func TestComputeUsage(t *testing.T) { + prev := []cpuCoreTicks{ + {core: 0, total: 1000, idle: 800}, + {core: 1, total: 1000, idle: 900}, + } + cur := []cpuCoreTicks{ + {core: 0, total: 1100, idle: 850}, // 100 total delta, 50 idle delta → 50% usage + {core: 1, total: 1200, idle: 950}, // 200 total delta, 50 idle delta → 75% usage + } + usage := computeUsage(prev, cur) + if len(usage) != 2 { + t.Fatalf("want 2 entries, got %d", len(usage)) + } + if usage[0].Core != 0 || usage[0].UsagePct != 50.0 { + t.Errorf("core0: got %+v, want 50%%", usage[0]) + } + if usage[1].Core != 1 || usage[1].UsagePct != 75.0 { + t.Errorf("core1: got %+v, want 75%%", usage[1]) + } + + // Empty prev → empty result. + if got := computeUsage(nil, cur); len(got) != 0 { + t.Errorf("nil prev should yield empty, got %d", len(got)) + } +} diff --git a/internal/modules/system/locale.go b/internal/modules/system/locale.go new file mode 100644 index 0000000..fecc643 --- /dev/null +++ b/internal/modules/system/locale.go @@ -0,0 +1,186 @@ +package system + +import ( + "context" + "slices" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +type LocaleStatusBody struct { + Lang string `json:"lang" example:"it_IT.UTF-8" doc:"System locale (LANG)"` + VCKeymap string `json:"vc_keymap" example:"it" doc:"Virtual console keymap"` + X11Layout string `json:"x11_layout" example:"it" doc:"X11 keyboard layout"` +} + +type GetLocaleOutput struct{ Body LocaleStatusBody } + +type LocalesOutput struct { + Body struct { + Locales []string `json:"locales" doc:"Available locales"` + } +} + +type KeymapsOutput struct { + Body struct { + Keymaps []string `json:"keymaps" doc:"Available virtual console keymaps"` + } +} + +type SetLocaleInput struct { + Body struct { + Lang string `json:"lang" example:"it_IT.UTF-8" doc:"Locale to set as LANG"` + } +} + +type SetKeymapInput struct { + Body struct { + Keymap string `json:"keymap" example:"it" doc:"Virtual console keymap"` + } +} + +func localeStatus() (LocaleStatusBody, error) { + lines, err := oscmd.RunLines("localectl", "status") + if err != nil { + return LocaleStatusBody{}, err + } + var b LocaleStatusBody + for _, line := range lines { + label, val, ok := strings.Cut(line, ":") + if !ok { + continue + } + switch strings.TrimSpace(label) { + case "System Locale": + for kv := range strings.FieldsSeq(val) { + if k, v, ok := strings.Cut(kv, "="); ok && k == "LANG" { + b.Lang = v + } + } + case "VC Keymap": + b.VCKeymap = strings.TrimSpace(val) + case "X11 Layout": + b.X11Layout = strings.TrimSpace(val) + } + } + return b, nil +} + +func registerLocale(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "system-get-locale", + Method: "GET", + Path: "/api/system/locale", + Summary: "Get locale and keyboard layout", + Description: "Returns the system locale (LANG), virtual console keymap, " + + "and X11 layout from localectl.", + Tags: []string{tagSystem}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*GetLocaleOutput, error) { + b, err := localeStatus() + if err != nil { + return nil, huma.Error500InternalServerError("localectl failed", err) + } + return &GetLocaleOutput{Body: b}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-list-locales", + Method: "GET", + Path: "/api/system/locales", + Summary: "List available locales", + Description: "Returns every locale the host can be configured to use, " + + "suitable for populating a selector.", + Tags: []string{tagSystem}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*LocalesOutput, error) { + locales, err := oscmd.RunLines("localectl", "list-locales") + if err != nil { + return nil, huma.Error500InternalServerError("localectl failed", err) + } + out := &LocalesOutput{} + out.Body.Locales = locales + return out, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-set-locale", + Method: "POST", + Path: "/api/system/locale", + Summary: "Set system locale (LANG)", + Description: "Sets LANG via localectl. The value is validated against the " + + "host's locale list, so an unknown locale returns 400.", + Tags: []string{tagSystem}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *SetLocaleInput) (*oscmd.StatusOutput, error) { + lang := strings.TrimSpace(in.Body.Lang) + if lang == "" { + return nil, huma.Error400BadRequest("empty locale") + } + locales, err := oscmd.RunLines("localectl", "list-locales") + if err != nil { + return nil, huma.Error500InternalServerError("localectl failed", err) + } + if !slices.Contains(locales, lang) { + return nil, huma.Error400BadRequest("unknown locale: " + lang) + } + if _, err := oscmd.Run("localectl", "set-locale", "LANG="+lang); err != nil { + return nil, huma.Error500InternalServerError("set-locale failed", err) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-list-keymaps", + Method: "GET", + Path: "/api/system/keymaps", + Summary: "List available console keymaps", + Description: "Returns every virtual console keymap known to the host.", + Tags: []string{tagSystem}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*KeymapsOutput, error) { + keymaps, err := oscmd.RunLines("localectl", "list-keymaps") + if err != nil { + return nil, huma.Error500InternalServerError("localectl failed", err) + } + out := &KeymapsOutput{} + out.Body.Keymaps = keymaps + return out, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-set-keymap", + Method: "POST", + Path: "/api/system/keymap", + Summary: "Set virtual console keymap", + Description: "Sets the virtual console keymap via localectl. The value is " + + "validated against the host's keymap list, so an unknown keymap " + + "returns 400.", + Tags: []string{tagSystem}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *SetKeymapInput) (*oscmd.StatusOutput, error) { + km := strings.TrimSpace(in.Body.Keymap) + if km == "" { + return nil, huma.Error400BadRequest("empty keymap") + } + keymaps, err := oscmd.RunLines("localectl", "list-keymaps") + if err != nil { + return nil, huma.Error500InternalServerError("localectl failed", err) + } + if !slices.Contains(keymaps, km) { + return nil, huma.Error400BadRequest("unknown keymap: " + km) + } + if _, err := oscmd.Run("localectl", "set-keymap", km); err != nil { + return nil, huma.Error500InternalServerError("set-keymap failed", err) + } + return oscmd.OK(), nil + }) +} diff --git a/internal/modules/system/module.go b/internal/modules/system/module.go new file mode 100644 index 0000000..5aab28d --- /dev/null +++ b/internal/modules/system/module.go @@ -0,0 +1,35 @@ +package system + +import ( + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +const ModuleID = "system" + +type Module struct{} + +func New() *Module { return &Module{} } + +func (m *Module) ID() string { return ModuleID } +func (m *Module) Name() string { return "System" } + +func (m *Module) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Read, rbac.Write, rbac.Root} +} + +func (m *Module) Register(api huma.API) { + registerInfo(api) + registerHostname(api) + registerTimedate(api) + registerLocale(api) + registerPower(api) +} + +func op(permission string) map[string]any { + return map[string]any{ + "module": ModuleID, + "permission": permission, + } +} diff --git a/internal/modules/system/power.go b/internal/modules/system/power.go new file mode 100644 index 0000000..3cca472 --- /dev/null +++ b/internal/modules/system/power.go @@ -0,0 +1,73 @@ +package system + +import ( + "context" + "regexp" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +// whenRe matches the shutdown(8) TIME forms we accept: "now", "+", or +// "hh:mm". Validating before exec keeps a value like "-r" or "--help" from +// being read as a flag (shutdown does not honor a "--" separator), per the +// repo's input-validation rule. +var whenRe = regexp.MustCompile(`^(now|\+\d+|([01]?\d|2[0-3]):[0-5]\d)$`) + +type PowerInput struct { + Body struct { + When string `json:"when,omitempty" example:"+1" doc:"When to act, as a shutdown(8) TIME (e.g. \"+5\" minutes, \"23:00\"). Empty or 'now' = immediate."` + } +} + +const powerDescription = "Requires the `root` permission. The response is sent once `shutdown` " + + "accepts the request; for the immediate form it returns before the machine " + + "actually goes down, so a 200 does not guarantee a clean shutdown completed." + +func registerPower(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "system-reboot", + Method: "POST", + Path: "/api/system/reboot", + Summary: "Reboot the system", + Description: powerDescription, + Tags: []string{tagSystem}, + Metadata: op("root"), + Errors: []int{400, 401, 403, 500}, + }, func(ctx context.Context, in *PowerInput) (*oscmd.StatusOutput, error) { + return schedulePower("reboot", in.Body.When) + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-poweroff", + Method: "POST", + Path: "/api/system/poweroff", + Summary: "Power off the system", + Description: powerDescription, + Tags: []string{tagSystem}, + Metadata: op("root"), + Errors: []int{400, 401, 403, 500}, + }, func(ctx context.Context, in *PowerInput) (*oscmd.StatusOutput, error) { + return schedulePower("poweroff", in.Body.When) + }) +} + +func schedulePower(action, when string) (*oscmd.StatusOutput, error) { + w := strings.TrimSpace(when) + if w == "" { + w = "now" + } + if !whenRe.MatchString(w) { + return nil, huma.Error400BadRequest("invalid when: " + w) + } + flag := "-h" + if action == "reboot" { + flag = "-r" + } + if _, err := oscmd.Run("shutdown", flag, w); err != nil { + return nil, huma.Error500InternalServerError("shutdown failed", err) + } + return oscmd.OK(), nil +} diff --git a/internal/modules/system/power_test.go b/internal/modules/system/power_test.go new file mode 100644 index 0000000..63a0e36 --- /dev/null +++ b/internal/modules/system/power_test.go @@ -0,0 +1,18 @@ +package system + +import "testing" + +func TestWhenRe(t *testing.T) { + valid := []string{"now", "+0", "+5", "+120", "0:00", "9:30", "23:59", "07:05"} + for _, w := range valid { + if !whenRe.MatchString(w) { + t.Errorf("whenRe.MatchString(%q) = false, want true", w) + } + } + invalid := []string{"", "-r", "-h", "--help", "24:00", "9:60", "+5; reboot", "now ", "5", "1:2"} + for _, w := range invalid { + if whenRe.MatchString(w) { + t.Errorf("whenRe.MatchString(%q) = true, want false", w) + } + } +} diff --git a/internal/modules/system/system_handler_test.go b/internal/modules/system/system_handler_test.go new file mode 100644 index 0000000..ca6b203 --- /dev/null +++ b/internal/modules/system/system_handler_test.go @@ -0,0 +1,235 @@ +package system + +import ( + "encoding/json" + "net/http" + "os" + "reflect" + "testing" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestSystemHandlers(t *testing.T) { + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + + registerHostname(api) + registerTimedate(api) + registerLocale(api) + registerPower(api) + registerInfo(api) + + // Mock uname for GET /api/system/info + oscmd.SetMock("uname", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"-r"}) { + return oscmd.MockCommand{Stdout: "6.9.1-1-test\n", ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"-m"}) { + return oscmd.MockCommand{Stdout: "x86_64\n", ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + defer oscmd.ClearMocks() + + // 1. Test GET /api/system/info + resp := api.Get("/api/system/info") + if resp.Code != http.StatusOK { + t.Errorf("get info: got %d, want %d", resp.Code, http.StatusOK) + } + + // 2. Test GET & POST /api/system/hostname + oscmd.SetMock("hostnamectl", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"hostname"}) { + return oscmd.MockCommand{Stdout: "server01\n", ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"set-hostname", "server02"}) { + return oscmd.MockCommand{ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + + resp = api.Get("/api/system/hostname") + if resp.Code != http.StatusOK { + t.Errorf("get hostname: got %d, want %d", resp.Code, http.StatusOK) + } + var hostnameRes GetHostnameOutput + if err := json.Unmarshal(resp.Body.Bytes(), &hostnameRes.Body); err != nil { + t.Fatal(err) + } + if hostnameRes.Body.Hostname != "server01" { + t.Errorf("got hostname %q, want %q", hostnameRes.Body.Hostname, "server01") + } + + resp = api.Post("/api/system/hostname", struct { + Hostname string `json:"hostname"` + }{ + Hostname: "server02", + }) + if resp.Code != http.StatusOK { + t.Errorf("set hostname: got %d, want %d", resp.Code, http.StatusOK) + } + + // 3. Test GET & POST /api/system/time + oscmd.SetMock("timedatectl", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"show"}) { + showOut := "Timezone=Europe/Rome\nLocalRTC=no\nNTP=yes\nNTPSynchronized=yes\nCanNTP=yes\n" + return oscmd.MockCommand{Stdout: showOut, ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"list-timezones"}) { + return oscmd.MockCommand{Stdout: "Europe/Rome\nUTC\n", ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"set-timezone", "Europe/Rome"}) { + return oscmd.MockCommand{ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"set-ntp", "true"}) { + return oscmd.MockCommand{ExitCode: 0} + } + if len(args) == 2 && args[0] == "set-time" { + return oscmd.MockCommand{ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + + resp = api.Get("/api/system/time") + if resp.Code != http.StatusOK { + t.Errorf("get time: got %d, want %d", resp.Code, http.StatusOK) + } + var timeRes GetTimeOutput + if err := json.Unmarshal(resp.Body.Bytes(), &timeRes.Body); err != nil { + t.Fatal(err) + } + if timeRes.Body.Timezone != "Europe/Rome" || !timeRes.Body.NTP { + t.Errorf("got time settings: %+v", timeRes.Body) + } + + resp = api.Get("/api/system/timezones") + if resp.Code != http.StatusOK { + t.Errorf("list timezones: got %d, want %d", resp.Code, http.StatusOK) + } + + resp = api.Post("/api/system/timezone", struct { + Timezone string `json:"timezone"` + }{ + Timezone: "Europe/Rome", + }) + if resp.Code != http.StatusOK { + t.Errorf("set timezone: got %d, want %d", resp.Code, http.StatusOK) + } + + resp = api.Post("/api/system/ntp", struct { + Enabled bool `json:"enabled"` + }{ + Enabled: true, + }) + if resp.Code != http.StatusOK { + t.Errorf("set ntp: got %d, want %d", resp.Code, http.StatusOK) + } + + resp = api.Post("/api/system/time", struct { + Time string `json:"time"` + }{ + Time: "2026-06-20T12:00:00Z", + }) + if resp.Code != http.StatusOK { + t.Errorf("set time: got %d, want %d", resp.Code, http.StatusOK) + } + + // 4. Test GET & POST /api/system/locale + oscmd.SetMock("localectl", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"status"}) { + statusOut := " System Locale: LANG=it_IT.UTF-8\n VC Keymap: it\n X11 Layout: it\n" + return oscmd.MockCommand{Stdout: statusOut, ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"list-locales"}) { + return oscmd.MockCommand{Stdout: "it_IT.UTF-8\nen_US.UTF-8\n", ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"set-locale", "LANG=it_IT.UTF-8"}) { + return oscmd.MockCommand{ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"list-keymaps"}) { + return oscmd.MockCommand{Stdout: "it\nus\n", ExitCode: 0} + } + if reflect.DeepEqual(args, []string{"set-keymap", "it"}) { + return oscmd.MockCommand{ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + + resp = api.Get("/api/system/locale") + if resp.Code != http.StatusOK { + t.Errorf("get locale: got %d, want %d", resp.Code, http.StatusOK) + } + var localeRes GetLocaleOutput + if err := json.Unmarshal(resp.Body.Bytes(), &localeRes.Body); err != nil { + t.Fatal(err) + } + if localeRes.Body.Lang != "it_IT.UTF-8" || localeRes.Body.VCKeymap != "it" { + t.Errorf("got locale status: %+v", localeRes.Body) + } + + resp = api.Get("/api/system/locales") + if resp.Code != http.StatusOK { + t.Errorf("list locales: got %d, want %d", resp.Code, http.StatusOK) + } + + resp = api.Post("/api/system/locale", struct { + Lang string `json:"lang"` + }{ + Lang: "it_IT.UTF-8", + }) + if resp.Code != http.StatusOK { + t.Errorf("set locale: got %d, want %d", resp.Code, http.StatusOK) + } + + resp = api.Get("/api/system/keymaps") + if resp.Code != http.StatusOK { + t.Errorf("list keymaps: got %d, want %d", resp.Code, http.StatusOK) + } + + resp = api.Post("/api/system/keymap", struct { + Keymap string `json:"keymap"` + }{ + Keymap: "it", + }) + if resp.Code != http.StatusOK { + t.Errorf("set keymap: got %d, want %d", resp.Code, http.StatusOK) + } + + // 5. Test POST /api/system/reboot and /api/system/poweroff + oscmd.SetMock("shutdown", func(args []string) oscmd.MockCommand { + if reflect.DeepEqual(args, []string{"-r", "now"}) || reflect.DeepEqual(args, []string{"-h", "now"}) { + return oscmd.MockCommand{ExitCode: 0} + } + return oscmd.MockCommand{ExitCode: 1} + }) + + resp = api.Post("/api/system/reboot", struct { + When string `json:"when"` + }{ + When: "now", + }) + if resp.Code != http.StatusOK { + t.Errorf("reboot: got %d, want %d", resp.Code, http.StatusOK) + } + + resp = api.Post("/api/system/poweroff", struct { + When string `json:"when"` + }{ + When: "now", + }) + if resp.Code != http.StatusOK { + t.Errorf("poweroff: got %d, want %d", resp.Code, http.StatusOK) + } +} diff --git a/internal/modules/system/timedate.go b/internal/modules/system/timedate.go new file mode 100644 index 0000000..29d5c5c --- /dev/null +++ b/internal/modules/system/timedate.go @@ -0,0 +1,200 @@ +package system + +import ( + "context" + "slices" + "strconv" + "strings" + "time" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +type TimeStatusBody struct { + Timezone string `json:"timezone" example:"Europe/Rome" doc:"IANA timezone name"` + LocalRTC bool `json:"local_rtc" doc:"Hardware clock kept in local time instead of UTC"` + NTP bool `json:"ntp" doc:"Network time synchronization enabled"` + NTPSynchronized bool `json:"ntp_synchronized" doc:"Clock is currently synchronized"` + CanNTP bool `json:"can_ntp" doc:"An NTP service is available on this host"` + Time string `json:"time" example:"2026-06-19T13:36:31Z" doc:"Current system time (RFC3339, UTC)"` +} + +type GetTimeOutput struct{ Body TimeStatusBody } + +type TimezonesOutput struct { + Body struct { + Timezones []string `json:"timezones" doc:"Available IANA timezone names"` + } +} + +type SetTimezoneInput struct { + Body struct { + Timezone string `json:"timezone" example:"Europe/Rome" doc:"IANA timezone name"` + } +} + +type SetNTPInput struct { + Body struct { + Enabled bool `json:"enabled" doc:"Enable network time synchronization"` + } +} + +type SetTimeInput struct { + Body struct { + Time string `json:"time" example:"2026-06-19T13:36:31Z" doc:"New time (RFC3339). Requires NTP disabled."` + } +} + +func timedatectlShow() (map[string]string, error) { + lines, err := oscmd.RunLines("timedatectl", "show") + if err != nil { + return nil, err + } + return oscmd.ParseKV(lines), nil +} + +// readTimeStatus builds the current time/timezone/sync status. Shared by the +// GET endpoint and the write endpoints, which return the resulting state so +// clients render ground truth instead of guessing after an opaque "ok". Note +// NTPSynchronized lags NTP by seconds-to-minutes while the NTP daemon converges. +func readTimeStatus() (TimeStatusBody, error) { + m, err := timedatectlShow() + if err != nil { + return TimeStatusBody{}, err + } + return TimeStatusBody{ + Timezone: m["Timezone"], + LocalRTC: m["LocalRTC"] == "yes", + NTP: m["NTP"] == "yes", + NTPSynchronized: m["NTPSynchronized"] == "yes", + CanNTP: m["CanNTP"] == "yes", + Time: time.Now().UTC().Format(time.RFC3339), + }, nil +} + +func registerTimedate(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "system-get-time", + Method: "GET", + Path: "/api/system/time", + Summary: "Get time, timezone and sync status", + Description: "Returns the current UTC time plus timezone and " + + "synchronization state from timedatectl.", + Tags: []string{tagSystem}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*GetTimeOutput, error) { + body, err := readTimeStatus() + if err != nil { + return nil, huma.Error500InternalServerError("timedatectl failed", err) + } + return &GetTimeOutput{Body: body}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-list-timezones", + Method: "GET", + Path: "/api/system/timezones", + Summary: "List available timezones", + Description: "Returns every IANA timezone name known to the host, " + + "suitable for populating a selector.", + Tags: []string{tagSystem}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*TimezonesOutput, error) { + zones, err := oscmd.RunLines("timedatectl", "list-timezones") + if err != nil { + return nil, huma.Error500InternalServerError("timedatectl failed", err) + } + out := &TimezonesOutput{} + out.Body.Timezones = zones + return out, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-set-timezone", + Method: "POST", + Path: "/api/system/timezone", + Summary: "Set system timezone", + Description: "Sets the timezone via timedatectl. The value is validated " + + "against the host's timezone list, so an unknown name returns 400.", + Tags: []string{tagSystem}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *SetTimezoneInput) (*oscmd.StatusOutput, error) { + tz := strings.TrimSpace(in.Body.Timezone) + if tz == "" { + return nil, huma.Error400BadRequest("empty timezone") + } + zones, err := oscmd.RunLines("timedatectl", "list-timezones") + if err != nil { + return nil, huma.Error500InternalServerError("timedatectl failed", err) + } + if !slices.Contains(zones, tz) { + return nil, huma.Error400BadRequest("unknown timezone: " + tz) + } + if _, err := oscmd.Run("timedatectl", "set-timezone", tz); err != nil { + return nil, huma.Error500InternalServerError("set-timezone failed", err) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-set-ntp", + Method: "POST", + Path: "/api/system/ntp", + Summary: "Enable or disable time synchronization", + Description: "Toggles network time synchronization via " + + "`timedatectl set-ntp`. Selecting specific NTP servers is not yet " + + "supported. Returns the resulting time status: on enable, `ntp` is " + + "true immediately, but `ntp_synchronized` stays false until the NTP " + + "daemon converges (seconds to minutes).", + Tags: []string{tagSystem}, + Metadata: op("write"), + Errors: []int{400, 401, 403, 409, 500}, + }, func(ctx context.Context, in *SetNTPInput) (*GetTimeOutput, error) { + // Enabling is a no-op when no NTP daemon is installed (CanNTP=no): + // timedatectl reports success but nothing syncs. Reject it clearly. + if in.Body.Enabled { + if m, err := timedatectlShow(); err == nil && m["CanNTP"] != "yes" { + return nil, huma.Error409Conflict("cannot enable NTP: no NTP service available on this host") + } + } + val := strconv.FormatBool(in.Body.Enabled) + if _, err := oscmd.Run("timedatectl", "set-ntp", val); err != nil { + return nil, huma.Error500InternalServerError("set-ntp failed", err) + } + body, err := readTimeStatus() + if err != nil { + return nil, huma.Error500InternalServerError("set-ntp succeeded but reading status failed", err) + } + return &GetTimeOutput{Body: body}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "system-set-time", + Method: "POST", + Path: "/api/system/time", + Summary: "Set system time manually", + Description: "Sets the clock to an explicit RFC3339 time. This only works " + + "when NTP is disabled; otherwise timedatectl refuses and the call " + + "returns 409.", + Tags: []string{tagSystem}, + Metadata: op("write"), + Errors: []int{400, 401, 403, 409, 500}, + }, func(ctx context.Context, in *SetTimeInput) (*oscmd.StatusOutput, error) { + t, err := time.Parse(time.RFC3339, strings.TrimSpace(in.Body.Time)) + if err != nil { + return nil, huma.Error400BadRequest("time must be RFC3339", err) + } + // timedatectl set-time interprets its argument as local wall-clock time. + stamp := t.Local().Format("2006-01-02 15:04:05") + if _, err := oscmd.Run("timedatectl", "set-time", stamp); err != nil { + // timedatectl refuses while NTP is active; make that actionable. + return nil, huma.Error409Conflict("set-time failed (disable NTP first?)", err) + } + return oscmd.OK(), nil + }) +} diff --git a/internal/modules/terminal/terminal.go b/internal/modules/terminal/terminal.go new file mode 100644 index 0000000..eb81ab0 --- /dev/null +++ b/internal/modules/terminal/terminal.go @@ -0,0 +1,210 @@ +package terminal + +import ( + "context" + "encoding/json" + "net/http" + "os/exec" + "strings" + "sync/atomic" + "time" + + "nadir/internal/auth" + "nadir/internal/module" + "nadir/internal/rbac" + + "github.com/coder/websocket" + "github.com/creack/pty" + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" +) + +// tagTerminal is the OpenAPI tag for this module (registered in server.go), +// keeping tags 1:1 with modules per the project convention. +const tagTerminal = "Terminal" + +const ( + // maxTerminals caps concurrent shells so a buggy frontend or careless admin + // can't pile up PTYs (guideline: limit everything). Raise if it's ever a real + // limit in practice. + maxTerminals = 10 + // idleTimeout closes a session after this long with no I/O in either + // direction, reclaiming abandoned shells. + idleTimeout = 15 * time.Minute +) + +// terminalSem is the concurrency limiter; a slot is held for the life of a session. +var terminalSem = make(chan struct{}, maxTerminals) + +type terminalModule struct { + sessions *auth.SessionStore +} + +// New creates a new Terminal module that allows interactive shell access. +func New(sessions *auth.SessionStore) module.Module { + return &terminalModule{sessions: sessions} +} + +func (m *terminalModule) ID() string { return "terminal" } +func (m *terminalModule) Name() string { return "Terminal" } + +func (m *terminalModule) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Root} +} + +type TerminalInput struct { + Ctx huma.Context `json:"-"` +} + +// Resolve extracts the huma.Context into the input struct. +func (i *TerminalInput) Resolve(ctx huma.Context) []error { + i.Ctx = ctx + return nil +} + +type resizeMessage struct { + Cols uint16 `json:"cols"` + Rows uint16 `json:"rows"` +} + +func (m *terminalModule) Register(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "terminal-connect", + Method: "GET", + Path: "/api/terminal", + Summary: "Connect to an interactive terminal", + Description: "Upgrades the connection to a WebSocket and spawns a PTY shell as the logged-in user. Send JSON `{cols, rows}` text messages to resize, and raw binary/text messages for stdin. This is a raw WebSocket endpoint — it cannot be exercised from the API docs \"Try it\" panel; use a WebSocket client.", + Tags: []string{tagTerminal}, + Metadata: map[string]any{"module": m.ID(), "permission": string(rbac.Root)}, + Errors: []int{401, 403, 426, 500}, + }, func(ctx context.Context, in *TerminalInput) (*struct{}, error) { + // The RBAC middleware already authenticated this request and enforced the + // "root" permission before we got here. We re-read the session only to get + // the username for `su`; the 401 below is the fallback for when the module + // is mounted without the middleware (e.g. in unit tests). + cookie, err := huma.ReadCookie(in.Ctx, "nadir_session_id") + if err != nil || cookie == nil { + return nil, huma.Error401Unauthorized("unauthorized") + } + + sess, ok := m.sessions.GetByToken(cookie.Value) + if !ok { + return nil, huma.Error401Unauthorized("unauthorized") + } + + req, res := humago.Unwrap(in.Ctx) + if req == nil || res == nil { + return nil, huma.Error500InternalServerError("missing http context") + } + + // Reject plain GETs (e.g. the docs "Try it" button) with a clear 426 rather + // than letting websocket.Accept emit a raw protocol-violation error. + if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + return nil, huma.NewError(http.StatusUpgradeRequired, + "this endpoint requires a WebSocket connection; connect with a WebSocket client") + } + + // InsecureSkipVerify is deliberately NOT set: coder/websocket then enforces + // that the Origin host matches the request Host, rejecting cross-site upgrade + // attempts. Defense-in-depth on top of the SameSite=Strict session cookie — + // important because this endpoint hands out an interactive shell. + conn, err := websocket.Accept(res, req, nil) + if err != nil { + // websocket.Accept already wrote the error response (e.g. 403 on an + // Origin mismatch). Just stop; writing again would corrupt the response. + return nil, nil + } + defer conn.CloseNow() + + // Bound concurrent shells: take a slot or reject (don't pile up PTYs). + select { + case terminalSem <- struct{}{}: + defer func() { <-terminalSem }() + default: + conn.Close(websocket.StatusTryAgainLater, "too many terminal sessions") + return nil, nil + } + + // Launch the user's login shell via su. + // "su - " ensures we get their actual environment and shell. + cmd := exec.CommandContext(req.Context(), "su", "-", sess.Username) + + // Start the command with a PTY. + ptmx, err := pty.Start(cmd) + if err != nil { + conn.Close(websocket.StatusInternalError, "failed to start pty") + return nil, nil + } + defer ptmx.Close() + + // lastActive is bumped by both pumps; the watchdog uses it to close idle + // sessions. Output activity (e.g. `top`) counts, so it isn't killed. + var lastActive atomic.Int64 + lastActive.Store(time.Now().UnixNano()) + go func() { + tick := time.NewTicker(idleTimeout / 4) + defer tick.Stop() + for { + select { + case <-req.Context().Done(): + return + case <-tick.C: + if time.Since(time.Unix(0, lastActive.Load())) > idleTimeout { + conn.Close(websocket.StatusGoingAway, "idle timeout") + return + } + } + } + }() + + // Pump stdout/stderr from PTY to WebSocket. + go func() { + buf := make([]byte, 8192) + for { + n, err := ptmx.Read(buf) + if err != nil { + break + } + lastActive.Store(time.Now().UnixNano()) + // We write PTY output as binary messages. The frontend (e.g., xterm.js) + // can handle UTF-8 binary or text transparently. + err = conn.Write(req.Context(), websocket.MessageBinary, buf[:n]) + if err != nil { + break + } + } + conn.Close(websocket.StatusNormalClosure, "") + }() + + // Pump stdin and resize commands from WebSocket to PTY. + for { + typ, b, err := conn.Read(req.Context()) + if err != nil { + break + } + lastActive.Store(time.Now().UnixNano()) + + if typ == websocket.MessageText { + var resize resizeMessage + if err := json.Unmarshal(b, &resize); err == nil && resize.Cols > 0 && resize.Rows > 0 { + // Handle resize + _ = pty.Setsize(ptmx, &pty.Winsize{ + Cols: resize.Cols, + Rows: resize.Rows, + }) + continue + } + } + + // If not a valid resize message, or if it's MessageBinary, pass to PTY stdin. + _, _ = ptmx.Write(b) + } + + // The read loop has ended (client gone or PTY closed). Tear down the shell + // and reap it: closing the PTY sends EOF, Kill covers shells that ignore it. + _ = ptmx.Close() + _ = cmd.Process.Kill() + _ = cmd.Wait() + return nil, nil + }) +} diff --git a/internal/modules/terminal/terminal_test.go b/internal/modules/terminal/terminal_test.go new file mode 100644 index 0000000..400b472 --- /dev/null +++ b/internal/modules/terminal/terminal_test.go @@ -0,0 +1,124 @@ +package terminal + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "nadir/internal/auth" + + "github.com/coder/websocket" + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" +) + +func TestTerminalConnectUnauthorized(t *testing.T) { + sessions, err := auth.NewSessionStore("file::memory:?cache=shared") + if err != nil { + t.Fatalf("session db: %v", err) + } + + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("Test", "1.0.0")) + mod := New(sessions) + mod.Register(api) + + srv := httptest.NewServer(mux) + defer srv.Close() + + wsURL := strings.Replace(srv.URL, "http://", "ws://", 1) + "/api/terminal" + + // Try without cookie + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) + defer cancel() + + _, resp, err := websocket.Dial(ctx, wsURL, nil) + if err == nil { + t.Fatal("expected error, got success") + } + if resp != nil && resp.StatusCode != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", resp.StatusCode) + } +} + +func TestTerminalConnectPlainGET(t *testing.T) { + // An authenticated but non-WebSocket GET (e.g. the docs "Try it" button) must + // get a clean 426 Upgrade Required, not a raw websocket protocol error. + sessions, err := auth.NewSessionStore("file::memory:?cache=shared") + if err != nil { + t.Fatalf("session db: %v", err) + } + token, err := sessions.Create("root") + if err != nil { + t.Fatalf("create session: %v", err) + } + + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("Test", "1.0.0")) + New(sessions).Register(api) + + srv := httptest.NewServer(mux) + defer srv.Close() + + req, _ := http.NewRequest(http.MethodGet, srv.URL+"/api/terminal", nil) + req.Header.Set("Cookie", "nadir_session_id="+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("GET failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusUpgradeRequired { + t.Errorf("expected 426, got %d", resp.StatusCode) + } +} + +func TestTerminalConnectAuthorized(t *testing.T) { + // Root or specific system configs might cause PTY or su to fail if run in constrained environments. + // But we expect the websocket upgrade to succeed and then maybe close with an error if PTY fails. + sessions, err := auth.NewSessionStore("file::memory:?cache=shared") + if err != nil { + t.Fatalf("session db: %v", err) + } + + token, err := sessions.Create("root") + if err != nil { + t.Fatalf("create session: %v", err) + } + + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("Test", "1.0.0")) + mod := New(sessions) + mod.Register(api) + + srv := httptest.NewServer(mux) + defer srv.Close() + + wsURL := strings.Replace(srv.URL, "http://", "ws://", 1) + "/api/terminal" + + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) + defer cancel() + + opts := &websocket.DialOptions{ + HTTPHeader: http.Header{}, + } + opts.HTTPHeader.Set("Cookie", "nadir_session_id="+token) + + conn, resp, err := websocket.Dial(ctx, wsURL, opts) + if err != nil { + // Depending on the test environment, "su" or "pty" might fail, but + // the websocket upgrade itself should succeed before it drops. + // If it fails to upgrade, that's a real error. + t.Fatalf("dial failed: %v (status %d)", err, resp.StatusCode) + } + defer conn.CloseNow() + + // The connection was upgraded successfully. + // If PTY allocation failed or su failed, the server might close the connection immediately. + // We just verify the upgrade succeeded. + if resp.StatusCode != http.StatusSwitchingProtocols { + t.Errorf("expected 101, got %d", resp.StatusCode) + } +} diff --git a/internal/modules/users/module.go b/internal/modules/users/module.go new file mode 100644 index 0000000..d3f6b1e --- /dev/null +++ b/internal/modules/users/module.go @@ -0,0 +1,30 @@ +package users + +import ( + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" +) + +const ModuleID = "users" + +type Module struct{} + +func New() *Module { return &Module{} } + +func (m *Module) ID() string { return ModuleID } +func (m *Module) Name() string { return "Users" } + +// Permissions: read to list/inspect accounts; write to create and change +// passwords; root to delete (irreversible). +func (m *Module) Permissions() []rbac.Permission { + return []rbac.Permission{rbac.Read, rbac.Write, rbac.Root} +} + +func (m *Module) Register(api huma.API) { + registerUsers(api) +} + +func op(permission string) map[string]any { + return map[string]any{"module": ModuleID, "permission": permission} +} diff --git a/internal/modules/users/users.go b/internal/modules/users/users.go new file mode 100644 index 0000000..8a85c3f --- /dev/null +++ b/internal/modules/users/users.go @@ -0,0 +1,373 @@ +package users + +import ( + "context" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" +) + +const tagUsers = "Users" + +var passwdPath = "/etc/passwd" + +// systemUIDMax is the conventional upper bound for system (non-login) accounts; +// regular users start at 1000 on Debian/Fedora. Used only to flag accounts for +// the client, not to filter them out. +const systemUIDMax = 1000 + +var ( + readErrors = []int{401, 403, 500} + writeErrors = []int{400, 401, 403, 404, 409, 500} +) + +// userNameRe matches valid Linux usernames (the useradd default NAME_REGEX). +// Starting with a letter/underscore also rejects leading-dash flag injection. +var userNameRe = regexp.MustCompile(`^[a-z_][a-z0-9_-]{0,31}\$?$`) + +// User mirrors one /etc/passwd entry. +type User struct { + Username string `json:"username" example:"alice" doc:"Login name"` + UID int `json:"uid" example:"1000" doc:"User ID"` + GID int `json:"gid" example:"1000" doc:"Primary group ID"` + Comment string `json:"comment" example:"Alice Smith" doc:"GECOS comment (often the full name)"` + Home string `json:"home" example:"/home/alice" doc:"Home directory"` + Shell string `json:"shell" example:"/bin/bash" doc:"Login shell"` + System bool `json:"system" doc:"True for system accounts (uid < 1000)"` +} + +type ListUsersOutput struct { + Body struct { + Users []User `json:"users" doc:"All accounts from /etc/passwd"` + } +} + +type GetUserOutput struct{ Body User } + +type UserPath struct { + Username string `path:"username" example:"alice" doc:"Login name"` +} + +type CreateUserInput struct { + Body struct { + Username string `json:"username" example:"alice" doc:"Login name"` + Comment string `json:"comment,omitempty" example:"Alice Smith" doc:"GECOS comment"` + Shell string `json:"shell,omitempty" example:"/bin/bash" doc:"Login shell"` + Home string `json:"home,omitempty" example:"/home/alice" doc:"Home directory (defaults to /home/)"` + CreateHome bool `json:"create_home,omitempty" doc:"Create the home directory (useradd -m)"` + System bool `json:"system,omitempty" doc:"Create a system account (useradd --system)"` + } +} + +type DeleteUserInput struct { + Username string `path:"username" example:"alice" doc:"Login name"` + RemoveHome bool `query:"remove_home" doc:"Also remove the home directory and mail spool (userdel -r)"` +} + +type SetPasswordInput struct { + Username string `path:"username" example:"alice" doc:"Login name"` + Body struct { + Password string `json:"password" doc:"New password (sent to chpasswd over stdin, never argv)"` + } +} + +type SetGroupsInput struct { + Username string `path:"username" example:"alice" doc:"Login name"` + Body struct { + Groups []string `json:"groups" doc:"Supplementary groups; replaces the user's full supplementary set"` + } +} + +type UserGroupsOutput struct { + Body struct { + Username string `json:"username" example:"alice"` + Groups []string `json:"groups" doc:"All groups the user belongs to (primary + supplementary)"` + } +} + +func registerUsers(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "users-list", + Method: "GET", + Path: "/api/users", + Summary: "List user accounts", + Description: "Returns every account in /etc/passwd, including system " + + "accounts (flagged via `system`).", + Tags: []string{tagUsers}, + Metadata: op("read"), + Errors: readErrors, + }, func(ctx context.Context, _ *struct{}) (*ListUsersOutput, error) { + list, err := listUsers() + if err != nil { + return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err) + } + out := &ListUsersOutput{} + out.Body.Users = list + return out, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "users-get", + Method: "GET", + Path: "/api/users/{username}", + Summary: "Get a single user", + Description: "Returns one account by login name. 404 if it does not exist.", + Tags: []string{tagUsers}, + Metadata: op("read"), + Errors: []int{400, 401, 403, 404, 500}, + }, func(ctx context.Context, in *UserPath) (*GetUserOutput, error) { + if err := validateUsername(in.Username); err != nil { + return nil, err + } + u, ok, err := lookupUser(in.Username) + if err != nil { + return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err) + } + if !ok { + return nil, huma.Error404NotFound("user not found: " + in.Username) + } + return &GetUserOutput{Body: u}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "users-create", + Method: "POST", + Path: "/api/users", + Summary: "Create a user account", + Description: "Creates an account via useradd. The new account has a locked " + + "password until one is set via the password endpoint. 409 if the user " + + "already exists.", + Tags: []string{tagUsers}, + Metadata: op("write"), + Errors: writeErrors, + }, func(ctx context.Context, in *CreateUserInput) (*GetUserOutput, error) { + if err := validateUsername(in.Body.Username); err != nil { + return nil, err + } + // -c/-s/-d are option *arguments*, so the "--" separator doesn't shield + // them. Validate at the boundary: a ':' or newline in the GECOS field + // would corrupt /etc/passwd; shell/home must be absolute paths. + if strings.ContainsAny(in.Body.Comment, ":\n") { + return nil, huma.Error400BadRequest("comment may not contain ':' or newlines") + } + if in.Body.Shell != "" && !filepath.IsAbs(in.Body.Shell) { + return nil, huma.Error400BadRequest("shell must be an absolute path") + } + if in.Body.Home != "" && !filepath.IsAbs(in.Body.Home) { + return nil, huma.Error400BadRequest("home must be an absolute path") + } + if _, ok, err := lookupUser(in.Body.Username); err != nil { + return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err) + } else if ok { + return nil, huma.Error409Conflict("user already exists: " + in.Body.Username) + } + + args := []string{} + if in.Body.System { + args = append(args, "--system") + } + if in.Body.CreateHome { + args = append(args, "-m") + } + if in.Body.Comment != "" { + args = append(args, "-c", in.Body.Comment) + } + if in.Body.Shell != "" { + args = append(args, "-s", in.Body.Shell) + } + if in.Body.Home != "" { + args = append(args, "-d", in.Body.Home) + } + args = append(args, "--", in.Body.Username) + if _, err := oscmd.Run("useradd", args...); err != nil { + return nil, huma.Error500InternalServerError("useradd failed", err) + } + + u, ok, err := lookupUser(in.Body.Username) + if err != nil || !ok { + return nil, huma.Error500InternalServerError("user created but could not be read back", err) + } + return &GetUserOutput{Body: u}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "users-delete", + Method: "DELETE", + Path: "/api/users/{username}", + Summary: "Delete a user account", + Description: "Removes an account via userdel. Pass ?remove_home=true to " + + "also delete the home directory. 404 if the user does not exist.", + Tags: []string{tagUsers}, + // Deleting an account is irreversible - gated behind root, not write. + Metadata: op("root"), + Errors: []int{400, 401, 403, 404, 500}, + }, func(ctx context.Context, in *DeleteUserInput) (*oscmd.StatusOutput, error) { + if err := validateUsername(in.Username); err != nil { + return nil, err + } + if _, ok, err := lookupUser(in.Username); err != nil { + return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err) + } else if !ok { + return nil, huma.Error404NotFound("user not found: " + in.Username) + } + + args := []string{} + if in.RemoveHome { + args = append(args, "-r") + } + args = append(args, "--", in.Username) + if _, err := oscmd.Run("userdel", args...); err != nil { + return nil, huma.Error500InternalServerError("userdel failed", err) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "users-set-password", + Method: "POST", + Path: "/api/users/{username}/password", + Summary: "Set a user's password", + Description: "Sets the password via chpasswd (fed over stdin, so the secret " + + "never appears in the process list). 404 if the user does not exist. " + + "Requires the `root` permission: resetting a privileged account's " + + "password (e.g. root) is a full-system action, not a routine write.", + Tags: []string{tagUsers}, + Metadata: op("root"), + Errors: []int{400, 401, 403, 404, 500}, + }, func(ctx context.Context, in *SetPasswordInput) (*oscmd.StatusOutput, error) { + if err := validateUsername(in.Username); err != nil { + return nil, err + } + if in.Body.Password == "" { + return nil, huma.Error400BadRequest("empty password") + } + // chpasswd reads one "name:password" line per stdin line, so a newline in + // the password would inject a second line and set another account's password. + if strings.ContainsAny(in.Body.Password, "\n\r") { + return nil, huma.Error400BadRequest("password may not contain newlines") + } + if _, ok, err := lookupUser(in.Username); err != nil { + return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err) + } else if !ok { + return nil, huma.Error404NotFound("user not found: " + in.Username) + } + // chpasswd reads "name:password" lines from stdin. + if _, err := oscmd.RunStdin(in.Username+":"+in.Body.Password+"\n", "chpasswd"); err != nil { + return nil, huma.Error500InternalServerError("chpasswd failed", err) + } + return oscmd.OK(), nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "users-set-groups", + Method: "PUT", + Path: "/api/users/{username}/groups", + Summary: "Set a user's supplementary groups", + Description: "Replaces the user's full supplementary group set via " + + "`usermod -G` (an empty list removes them from all supplementary " + + "groups). Returns the resulting group membership. 404 if the user " + + "does not exist; 400 if any named group is missing. Requires the " + + "`root` permission: adding an account to wheel/sudo/docker is a " + + "privilege grant, not a routine write.", + Tags: []string{tagUsers}, + Metadata: op("root"), + Errors: writeErrors, + }, func(ctx context.Context, in *SetGroupsInput) (*UserGroupsOutput, error) { + if err := validateUsername(in.Username); err != nil { + return nil, err + } + for _, g := range in.Body.Groups { + // Group names follow the same rule as usernames; this also blocks + // flag injection and a stray comma turning one group into two. + if !userNameRe.MatchString(g) { + return nil, huma.Error400BadRequest("invalid group name: " + g) + } + } + if _, ok, err := lookupUser(in.Username); err != nil { + return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err) + } else if !ok { + return nil, huma.Error404NotFound("user not found: " + in.Username) + } + if _, err := oscmd.Run("usermod", "-G", strings.Join(in.Body.Groups, ","), "--", in.Username); err != nil { + return nil, huma.Error400BadRequest("usermod failed (does every named group exist?)", err) + } + + // id -nG lists all groups (primary + supplementary) the user now has. + out, err := oscmd.Run("id", "-nG", in.Username) + if err != nil { + return nil, huma.Error500InternalServerError("groups set but read-back failed", err) + } + res := &UserGroupsOutput{} + res.Body.Username = in.Username + res.Body.Groups = strings.Fields(out) + return res, nil + }) +} + +// validateUsername rejects empty, flag-like, or malformed names before exec. +func validateUsername(name string) error { + if !userNameRe.MatchString(name) { + return huma.Error400BadRequest("invalid username: " + name) + } + return nil +} + +func listUsers() ([]User, error) { + data, err := os.ReadFile(passwdPath) + if err != nil { + return nil, err + } + return parsePasswd(data), nil +} + +func lookupUser(name string) (User, bool, error) { + list, err := listUsers() + if err != nil { + return User{}, false, err + } + for _, u := range list { + if u.Username == name { + return u, true, nil + } + } + return User{}, false, nil +} + +// parsePasswd parses /etc/passwd content. Lines that are blank, commented, or +// malformed (fewer than 7 fields, non-numeric ids) are skipped. +func parsePasswd(data []byte) []User { + var users []User + for line := range strings.SplitSeq(string(data), "\n") { + if line == "" || strings.HasPrefix(line, "#") { + continue + } + f := strings.Split(line, ":") + if len(f) < 7 { + continue + } + uid, err := strconv.Atoi(f[2]) + if err != nil { + continue + } + gid, err := strconv.Atoi(f[3]) + if err != nil { + continue + } + users = append(users, User{ + Username: f[0], + UID: uid, + GID: gid, + Comment: f[4], + Home: f[5], + Shell: f[6], + System: uid < systemUIDMax, + }) + } + return users +} diff --git a/internal/modules/users/users_handler_test.go b/internal/modules/users/users_handler_test.go new file mode 100644 index 0000000..2121a23 --- /dev/null +++ b/internal/modules/users/users_handler_test.go @@ -0,0 +1,152 @@ +package users + +import ( + "encoding/json" + "net/http" + "os" + "path/filepath" + "reflect" + "testing" + + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestUsersHandlers(t *testing.T) { + tempPasswd := filepath.Join(t.TempDir(), "passwd") + initialContent := "root:x:0:0:root:/root:/bin/bash\nalice:x:1000:1000:Alice Smith:/home/alice:/bin/bash\n" + if err := os.WriteFile(tempPasswd, []byte(initialContent), 0644); err != nil { + t.Fatal(err) + } + + oldPasswd := passwdPath + passwdPath = tempPasswd + defer func() { passwdPath = oldPasswd }() + + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + registerUsers(api) + + // 1. Test GET /api/users + resp := api.Get("/api/users") + if resp.Code != http.StatusOK { + t.Errorf("list users: got %d, want %d", resp.Code, http.StatusOK) + } + var listRes ListUsersOutput + if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil { + t.Fatal(err) + } + if len(listRes.Body.Users) != 2 { + t.Errorf("got %d users, want 2", len(listRes.Body.Users)) + } + + // 2. Test GET /api/users/{username} + resp = api.Get("/api/users/alice") + if resp.Code != http.StatusOK { + t.Errorf("get user: got %d, want %d", resp.Code, http.StatusOK) + } + var getRes GetUserOutput + if err := json.Unmarshal(resp.Body.Bytes(), &getRes.Body); err != nil { + t.Fatal(err) + } + if getRes.Body.Username != "alice" || getRes.Body.UID != 1000 { + t.Errorf("get user: got %+v", getRes.Body) + } + + resp = api.Get("/api/users/bob") + if resp.Code != http.StatusNotFound { + t.Errorf("get non-existent user: got %d, want %d", resp.Code, http.StatusNotFound) + } + + // 3. Test POST /api/users + oscmd.SetMock("useradd", func(args []string) oscmd.MockCommand { + wantArgs := []string{"-m", "-c", "Bob Jones", "-s", "/bin/sh", "--", "bob"} + if !reflect.DeepEqual(args, wantArgs) { + t.Errorf("useradd args: got %v, want %v", args, wantArgs) + } + bobContent := initialContent + "bob:x:1001:1001:Bob Jones:/home/bob:/bin/sh\n" + os.WriteFile(tempPasswd, []byte(bobContent), 0644) + return oscmd.MockCommand{ExitCode: 0} + }) + defer oscmd.ClearMocks() + + resp = api.Post("/api/users", struct { + Username string `json:"username"` + Comment string `json:"comment"` + Shell string `json:"shell"` + CreateHome bool `json:"create_home"` + }{ + Username: "bob", + Comment: "Bob Jones", + Shell: "/bin/sh", + CreateHome: true, + }) + if resp.Code != http.StatusOK { + t.Errorf("create user: got %d, want %d", resp.Code, http.StatusOK) + } + + // 4. Test DELETE /api/users/{username} + oscmd.SetMock("userdel", func(args []string) oscmd.MockCommand { + wantArgs := []string{"-r", "--", "bob"} + if !reflect.DeepEqual(args, wantArgs) { + t.Errorf("userdel args: got %v, want %v", args, wantArgs) + } + os.WriteFile(tempPasswd, []byte(initialContent), 0644) + return oscmd.MockCommand{ExitCode: 0} + }) + + resp = api.Delete("/api/users/bob?remove_home=true") + if resp.Code != http.StatusOK { + t.Errorf("delete user: got %d, want %d", resp.Code, http.StatusOK) + } + + // 5. Test POST /api/users/{username}/password + oscmd.SetMock("chpasswd", func(args []string) oscmd.MockCommand { + return oscmd.MockCommand{ExitCode: 0} + }) + + resp = api.Post("/api/users/alice/password", struct { + Password string `json:"password"` + }{ + Password: "newsecretpwd", + }) + if resp.Code != http.StatusOK { + t.Errorf("set password: got %d, want %d", resp.Code, http.StatusOK) + } + + // 6. Test PUT /api/users/{username}/groups + oscmd.SetMock("usermod", func(args []string) oscmd.MockCommand { + wantArgs := []string{"-G", "wheel,dev", "--", "alice"} + if !reflect.DeepEqual(args, wantArgs) { + t.Errorf("usermod args: got %v, want %v", args, wantArgs) + } + return oscmd.MockCommand{ExitCode: 0} + }) + + oscmd.SetMock("id", func(args []string) oscmd.MockCommand { + wantArgs := []string{"-nG", "alice"} + if !reflect.DeepEqual(args, wantArgs) { + t.Errorf("id args: got %v, want %v", args, wantArgs) + } + return oscmd.MockCommand{Stdout: "alice wheel dev\n", ExitCode: 0} + }) + + resp = api.Put("/api/users/alice/groups", struct { + Groups []string `json:"groups"` + }{ + Groups: []string{"wheel", "dev"}, + }) + if resp.Code != http.StatusOK { + t.Errorf("set groups: got %d, want %d", resp.Code, http.StatusOK) + } +} diff --git a/internal/modules/users/users_test.go b/internal/modules/users/users_test.go new file mode 100644 index 0000000..58cb356 --- /dev/null +++ b/internal/modules/users/users_test.go @@ -0,0 +1,53 @@ +package users + +import "testing" + +func TestParsePasswd(t *testing.T) { + data := []byte(`root:x:0:0:root:/root:/bin/bash +# a comment + +daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin +alice:x:1000:1000:Alice Smith:/home/alice:/bin/bash +broken:x:notanumber:5::: +short:x:2:2 +`) + got := parsePasswd(data) + if len(got) != 3 { + t.Fatalf("expected 3 valid users, got %d: %+v", len(got), got) + } + + alice := got[2] + if alice.Username != "alice" || alice.UID != 1000 || alice.GID != 1000 || + alice.Comment != "Alice Smith" || alice.Home != "/home/alice" || + alice.Shell != "/bin/bash" || alice.System { + t.Errorf("alice parsed wrong: %+v", alice) + } + if !got[0].System || !got[1].System { + t.Error("root/daemon should be flagged as system accounts") + } +} + +func TestValidateUsername(t *testing.T) { + valid := []string{"alice", "_svc", "user-1", "a", "machine$", "abc_def"} + for _, n := range valid { + if err := validateUsername(n); err != nil { + t.Errorf("validateUsername(%q) = %v, want nil", n, err) + } + } + + invalid := []string{ + "", // empty + "-rf", // leading dash (flag injection) + "Alice", // uppercase + "1user", // leading digit + "a b", // space + "foo;rm", // shell metachar + "root:x", // colon (passwd separator) + "waytoolongusernamethatexceedsthirtytwochars", // >32 + } + for _, n := range invalid { + if err := validateUsername(n); err == nil { + t.Errorf("validateUsername(%q) = nil, want error", n) + } + } +} diff --git a/internal/mounts/mounts.go b/internal/mounts/mounts.go new file mode 100644 index 0000000..9f2995d --- /dev/null +++ b/internal/mounts/mounts.go @@ -0,0 +1,68 @@ +// Package mounts parses the kernel mount table (/proc/mounts). Both the system +// dashboard (disk usage) and the storage module (mount management) need it, so +// it lives here rather than being duplicated. It also exposes the octal +// unescaping that /proc/mounts and /etc/fstab use for spaces/tabs in paths. +package mounts + +import ( + "os" + "strings" +) + +// procMounts is a var so tests can point it at a fixture. +var procMounts = "/proc/mounts" + +// Mount is one mount-table line: the backing device, where it's mounted, the +// filesystem type, and the comma-separated mount options. +type Mount struct { + Device string `json:"device" example:"/dev/sda1"` + Mountpoint string `json:"mountpoint" example:"/mnt/data"` + FSType string `json:"fstype" example:"ext4"` + Options string `json:"options" example:"rw,relatime"` +} + +// Proc reads and parses /proc/mounts. +func Proc() ([]Mount, error) { + data, err := os.ReadFile(procMounts) + if err != nil { + return nil, err + } + return parseProc(string(data)), nil +} + +func parseProc(data string) []Mount { + entries := []Mount{} + for line := range strings.SplitSeq(data, "\n") { + f := strings.Fields(line) + if len(f) < 4 { // device mountpoint fstype options [dump pass] + continue + } + entries = append(entries, Mount{ + Device: Unescape(f[0]), + Mountpoint: Unescape(f[1]), + FSType: f[2], + Options: f[3], + }) + } + return entries +} + +// Unescape decodes the octal escapes (\040 space, \011 tab, \012 newline, +// \134 backslash) that mount tables and fstab use for whitespace in fields. +func Unescape(s string) string { + if !strings.ContainsRune(s, '\\') { + return s + } + var b strings.Builder + for i := 0; i < len(s); i++ { + if s[i] == '\\' && i+3 < len(s) && isOctal(s[i+1]) && isOctal(s[i+2]) && isOctal(s[i+3]) { + b.WriteByte((s[i+1]-'0')*64 + (s[i+2]-'0')*8 + (s[i+3] - '0')) + i += 3 + continue + } + b.WriteByte(s[i]) + } + return b.String() +} + +func isOctal(c byte) bool { return c >= '0' && c <= '7' } diff --git a/internal/mounts/mounts_test.go b/internal/mounts/mounts_test.go new file mode 100644 index 0000000..2e667e3 --- /dev/null +++ b/internal/mounts/mounts_test.go @@ -0,0 +1,58 @@ +package mounts + +import ( + "os" + "path/filepath" + "reflect" + "testing" +) + +func TestParseProc(t *testing.T) { + // Real /proc/mounts shape, including a pseudo fs and an octal-escaped space. + data := `proc /proc proc rw,nosuid,nodev,noexec 0 0 +/dev/sda1 / ext4 rw,relatime 0 0 +/dev/sdb1 /mnt/my\040disk ext4 rw 0 0 +` + got := parseProc(data) + want := []Mount{ + {Device: "proc", Mountpoint: "/proc", FSType: "proc", Options: "rw,nosuid,nodev,noexec"}, + {Device: "/dev/sda1", Mountpoint: "/", FSType: "ext4", Options: "rw,relatime"}, + {Device: "/dev/sdb1", Mountpoint: "/mnt/my disk", FSType: "ext4", Options: "rw"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("parseProc:\n got %+v\nwant %+v", got, want) + } +} + +func TestUnescape(t *testing.T) { + cases := map[string]string{ + `/mnt/my\040disk`: "/mnt/my disk", + `/no/escapes`: "/no/escapes", + `tab\011here`: "tab\there", + `back\134slash`: `back\slash`, + } + for in, want := range cases { + if got := Unescape(in); got != want { + t.Errorf("Unescape(%q) = %q, want %q", in, got, want) + } + } +} + +func TestProc(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "mounts") + if err := os.WriteFile(f, []byte("/dev/sda1 / ext4 rw 0 0\n"), 0644); err != nil { + t.Fatal(err) + } + old := procMounts + procMounts = f + defer func() { procMounts = old }() + + got, err := Proc() + if err != nil { + t.Fatal(err) + } + if len(got) != 1 || got[0].Mountpoint != "/" { + t.Errorf("Proc() = %+v", got) + } +} diff --git a/internal/openapitest/openapi_test.go b/internal/openapitest/openapi_test.go new file mode 100644 index 0000000..ae210ea --- /dev/null +++ b/internal/openapitest/openapi_test.go @@ -0,0 +1,71 @@ +// Package openapitest holds a cross-module guard: it registers every module +// into one Huma API and generates the spec. Huma names OpenAPI schemas by Go +// type name alone (not package-qualified), so two modules declaring an +// identically named response type (e.g. two "ListOutput"s) collide only when +// registered together - a failure invisible to per-package `go build`/`go test`. +package openapitest + +import ( + "net/http" + "path/filepath" + "testing" + + "nadir/internal/auditlog" + "nadir/internal/auth" + "nadir/internal/meta" + "nadir/internal/module" + "nadir/internal/modules/audit" + "nadir/internal/modules/groups" + "nadir/internal/modules/networking" + "nadir/internal/modules/packages" + "nadir/internal/modules/services" + "nadir/internal/modules/storage" + "nadir/internal/modules/system" + "nadir/internal/modules/terminal" + "nadir/internal/modules/users" + "nadir/internal/rbac" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" +) + +func TestOpenAPISchemaNoCollisions(t *testing.T) { + auditStore, err := auditlog.New(filepath.Join(t.TempDir(), "audit.db")) + if err != nil { + t.Fatal(err) + } + sessions, err := auth.NewSessionStore(filepath.Join(t.TempDir(), "sessions.db")) + if err != nil { + t.Fatal(err) + } + roles := rbac.New() + + // Mirror cmd/server.go's registration so this catches real collisions. + mods := []module.Module{ + system.New(), + services.New(nil), + users.New(), + groups.New(), + packages.New(), + networking.New(), + storage.New(), + audit.New(auditStore), + terminal.New(sessions), + } + + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("test", "1.0.0")) + for _, m := range mods { + m.Register(api) // huma panics here on a duplicate schema name + } + meta.Register(api, mods) + meta.RegisterHealth(api, sessions) + meta.RegisterWhoami(api, sessions, roles, mods) + auth.RegisterLogin(api, sessions, auditStore, true) + auth.RegisterLogout(api, sessions, true) + + // Force full schema resolution. + if _, err := api.OpenAPI().YAML(); err != nil { + t.Fatalf("OpenAPI generation failed: %v", err) + } +} diff --git a/internal/oscmd/oscmd.go b/internal/oscmd/oscmd.go new file mode 100644 index 0000000..449ba70 --- /dev/null +++ b/internal/oscmd/oscmd.go @@ -0,0 +1,354 @@ +// Package oscmd holds small helpers shared by modules that shell out to system +// tools (hostnamectl, timedatectl, systemctl, …): a command runner that +// surfaces stderr, and the common "status: ok" HTTP response. +package oscmd + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "os/exec" + "strings" + "sync" + "time" +) + +// cmdTimeout caps how long a synchronous system command may run before it is +// killed, so a wedged tool can't tie up a request indefinitely. +// +// It bounds hangs for the plain Run path (which uses context.Background()). +// RunContext additionally propagates client cancellation. Long ops (package +// install/upgrade) use the streaming runners below, which take a ctx and are +// uncapped. +const cmdTimeout = 60 * time.Second + +// CommandRunner allows overriding the command execution function for testing. +var CommandRunner = func(ctx context.Context, name string, args ...string) *exec.Cmd { + mockMu.Lock() + handler, ok := mockCmds[name] + mockMu.Unlock() + + if ok { + mockRes := handler(args) + tmpFile, err := os.CreateTemp("", "nadir-mock-*.json") + if err != nil { + log.Printf("oscmd mock: failed to create temp file: %v", err) + return exec.CommandContext(ctx, name, args...) + } + + encoder := json.NewEncoder(tmpFile) + if err := encoder.Encode(mockRes); err != nil { + log.Printf("oscmd mock: failed to write json: %v", err) + tmpFile.Close() + return exec.CommandContext(ctx, name, args...) + } + tmpPath := tmpFile.Name() + tmpFile.Close() + + cmd := exec.CommandContext(ctx, os.Args[0]) + cmd.Env = append(os.Environ(), + "GO_WANT_HELPER_PROCESS=1", + "NADIR_MOCK_FILE="+tmpPath, + ) + return cmd + } + + return exec.CommandContext(ctx, name, args...) +} + +// Run executes name with args and returns trimmed stdout. On failure it wraps +// the command's stderr (falling back to the exec error) so handlers can surface +// a meaningful message instead of a bare "exit status 1". +// +// Run uses context.Background(): it is bounded by cmdTimeout but not tied to any +// request. Callers running a slow command on behalf of a request (so a client +// disconnect should kill it) should use RunContext instead. +func Run(name string, args ...string) (string, error) { + return RunContext(context.Background(), name, args...) +} + +// RunContext is Run with a caller-supplied context. The command is killed when +// ctx is cancelled (e.g. the client disconnected) or after cmdTimeout, whichever +// comes first. +func RunContext(ctx context.Context, name string, args ...string) (string, error) { + ctx, cancel := context.WithTimeout(ctx, cmdTimeout) + defer cancel() + cmd := CommandRunner(ctx, name, args...) + var stderr strings.Builder + cmd.Stderr = &stderr + out, err := cmd.Output() + if err != nil { + if msg := strings.TrimSpace(stderr.String()); msg != "" { + return "", errors.New(msg) + } + return "", err + } + return strings.TrimSpace(string(out)), nil +} + +// RunStdin is Run with data fed to the command's stdin. Use it for secrets +// (e.g. piping "user:password" to chpasswd) so they never appear in argv/ps. +func RunStdin(stdin, name string, args ...string) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), cmdTimeout) + defer cancel() + cmd := CommandRunner(ctx, name, args...) + cmd.Stdin = strings.NewReader(stdin) + var stderr strings.Builder + cmd.Stderr = &stderr + out, err := cmd.Output() + if err != nil { + if msg := strings.TrimSpace(stderr.String()); msg != "" { + return "", errors.New(msg) + } + return "", err + } + return strings.TrimSpace(string(out)), nil +} + +// RunLines runs the command and splits stdout into non-empty lines. +func RunLines(name string, args ...string) ([]string, error) { + out, err := Run(name, args...) + if err != nil { + return nil, err + } + if out == "" { + return []string{}, nil + } + return strings.Split(out, "\n"), nil +} + +// RunStatus runs name and returns trimmed stdout plus the process exit code. It +// does NOT treat a non-zero exit as an error, because some tools signal state +// that way (dnf `check-update` exits 100 when updates exist; pacman `-Qu` exits +// 1 when there are none). err is set only when the command could not be run. +func RunStatus(name string, args ...string) (string, int, error) { + ctx, cancel := context.WithTimeout(context.Background(), cmdTimeout) + defer cancel() + cmd := CommandRunner(ctx, name, args...) + out, err := cmd.Output() + if err != nil { + var ee *exec.ExitError + if errors.As(err, &ee) { + return strings.TrimSpace(string(out)), ee.ExitCode(), nil + } + return "", -1, err + } + return strings.TrimSpace(string(out)), 0, nil +} + +// RunStream runs a long-lived command (e.g. `journalctl -f`) and pushes its +// stdout lines onto the returned channel until the process exits or ctx is +// cancelled. Cancelling ctx kills the process (exec.CommandContext) and closes +// the channel, so an SSE handler can stop simply by cancelling its request +// context (client disconnect). +func RunStream(ctx context.Context, name string, args ...string) (<-chan string, error) { + cmd := CommandRunner(ctx, name, args...) + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + if err := cmd.Start(); err != nil { + return nil, err + } + ch := make(chan string) + go func() { + defer close(ch) + defer cmd.Wait() // reap; process is already killed on ctx cancel + sc := bufio.NewScanner(stdout) + sc.Buffer(make([]byte, 64*1024), 1024*1024) // tolerate long log lines + for sc.Scan() { + select { + case ch <- sc.Text(): + case <-ctx.Done(): + return + } + } + // A scan error (e.g. a line over the buffer cap) ends the loop silently + // otherwise; surface it so a truncated stream is diagnosable. Context + // cancellation returns above, so this only fires on a real read error. + if err := sc.Err(); err != nil && ctx.Err() == nil { + log.Printf("oscmd: stream %s: %v", name, err) + } + }() + return ch, nil +} + +// RunStreamCombined runs a command and streams its merged stdout+stderr line by +// line, the way the command's own terminal output looks. Lines are split on \r +// as well as \n so progress redraws (apt/dnf) stream instead of buffering until +// a newline. extraEnv is appended to the process environment. +// +// It returns a lines channel and a one-shot error channel that delivers the +// process's exit status after the lines channel closes (nil = success). +// Cancelling ctx kills the process. +func RunStreamCombined(ctx context.Context, extraEnv []string, name string, args ...string) (<-chan string, <-chan error, error) { + cmd := CommandRunner(ctx, name, args...) + if len(extraEnv) > 0 { + if cmd.Env == nil { + cmd.Env = append(os.Environ(), extraEnv...) + } else { + cmd.Env = append(cmd.Env, extraEnv...) + } + } + // One OS pipe for both streams gives the natural interleaving; passing an + // *os.File hands the fd straight to the child (no copy goroutine). + pr, pw, err := os.Pipe() + if err != nil { + return nil, nil, err + } + cmd.Stdout, cmd.Stderr = pw, pw + if err := cmd.Start(); err != nil { + pr.Close() + pw.Close() + return nil, nil, err + } + pw.Close() // the child holds its own dup; the reader sees EOF when it exits + + lines := make(chan string) + errc := make(chan error, 1) + go func() { + defer close(lines) + defer func() { errc <- cmd.Wait() }() + defer pr.Close() + sc := bufio.NewScanner(pr) + sc.Buffer(make([]byte, 64*1024), 1024*1024) + sc.Split(scanLinesCR) + for sc.Scan() { + if line := sc.Text(); line != "" { + select { + case lines <- line: + case <-ctx.Done(): + return + } + } + } + // Surface a read error (e.g. a line over the buffer cap) so a truncated + // stream is diagnosable. ctx cancellation returns above, so this only + // fires on a real error; the exit status still travels via errc. + if err := sc.Err(); err != nil && ctx.Err() == nil { + log.Printf("oscmd: stream %s: %v", name, err) + } + }() + return lines, errc, nil +} + +// scanLinesCR is a bufio.SplitFunc that breaks on either \r or \n, so terminal +// progress redraws (carriage returns) surface as they happen. +func scanLinesCR(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexAny(data, "\r\n"); i >= 0 { + return i + 1, data[:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil // request more data +} + +// ParseKV parses key=value lines (e.g. from timedatectl/systemctl show) into a map. +func ParseKV(lines []string) map[string]string { + m := make(map[string]string, len(lines)) + for _, line := range lines { + if k, v, ok := strings.Cut(line, "="); ok { + m[k] = v + } + } + return m +} + +// StatusOutput is the shared response for write operations that just report +// success. Reusing one type means all such endpoints share a single OpenAPI +// schema. +type StatusOutput struct { + Body struct { + Status string `json:"status" example:"ok" doc:"Always \"ok\" on success"` + } +} + +// OK returns a populated StatusOutput. +func OK() *StatusOutput { + out := &StatusOutput{} + out.Body.Status = "ok" + return out +} + +// --- Mocking helpers for testing --------------------------------------------- + +// MockCommand holds the behavior for a mocked command. +type MockCommand struct { + Stdout string `json:"stdout"` + Stderr string `json:"stderr"` + ExitCode int `json:"exit_code"` + Lines []string `json:"lines,omitempty"` + DelayMs int `json:"delay_ms,omitempty"` +} + +var ( + mockMu sync.Mutex + mockCmds = make(map[string]func(args []string) MockCommand) +) + +// SetMock registers a mock handler function for the given command name. +func SetMock(name string, handler func(args []string) MockCommand) { + mockMu.Lock() + defer mockMu.Unlock() + mockCmds[name] = handler +} + +// ClearMocks removes all registered mock command handlers. +func ClearMocks() { + mockMu.Lock() + defer mockMu.Unlock() + clear(mockCmds) +} + +// RunHelperProcess executes the mock helper process logic if GO_WANT_HELPER_PROCESS is set. +// It returns true if it ran (and exits the process), false otherwise. +func RunHelperProcess() bool { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return false + } + mockFile := os.Getenv("NADIR_MOCK_FILE") + if mockFile == "" { + os.Exit(1) + } + defer os.Remove(mockFile) + + data, err := os.ReadFile(mockFile) + if err != nil { + fmt.Fprintf(os.Stderr, "mock helper: read failed: %v\n", err) + os.Exit(1) + } + + var mock MockCommand + if err := json.Unmarshal(data, &mock); err != nil { + fmt.Fprintf(os.Stderr, "mock helper: unmarshal failed: %v\n", err) + os.Exit(1) + } + + if len(mock.Lines) > 0 { + for _, line := range mock.Lines { + fmt.Fprintln(os.Stdout, line) + if mock.DelayMs > 0 { + time.Sleep(time.Duration(mock.DelayMs) * time.Millisecond) + } + } + } else { + if mock.Stdout != "" { + fmt.Fprint(os.Stdout, mock.Stdout) + } + if mock.Stderr != "" { + fmt.Fprint(os.Stderr, mock.Stderr) + } + } + + os.Exit(mock.ExitCode) + return true +} diff --git a/internal/oscmd/oscmd_test.go b/internal/oscmd/oscmd_test.go new file mode 100644 index 0000000..0b1fc0c --- /dev/null +++ b/internal/oscmd/oscmd_test.go @@ -0,0 +1,143 @@ +package oscmd + +import ( + "os" + "reflect" + "strings" + "testing" +) + +func TestMain(m *testing.M) { + if RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestRunTrimsStdout(t *testing.T) { + out, err := Run("echo", "hello") + if err != nil { + t.Fatal(err) + } + if out != "hello" { + t.Errorf("got %q, want %q", out, "hello") + } +} + +func TestRunSurfacesStderr(t *testing.T) { + _, err := Run("sh", "-c", "echo boom >&2; exit 1") + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected stderr 'boom' in error, got %v", err) + } +} + +func TestRunMissingBinary(t *testing.T) { + if _, err := Run("definitely-not-a-real-binary-xyz"); err == nil { + t.Fatal("expected error for missing binary") + } +} + +func TestRunLines(t *testing.T) { + got, err := RunLines("printf", `a\nb\nc\n`) + if err != nil { + t.Fatal(err) + } + want := []string{"a", "b", "c"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestRunLinesEmpty(t *testing.T) { + got, err := RunLines("true") + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Errorf("empty output should yield no lines, got %v", got) + } +} + +func TestOK(t *testing.T) { + if OK().Body.Status != "ok" { + t.Error("OK() should report status ok") + } +} + +func TestRunStdin(t *testing.T) { + SetMock("cat", func(args []string) MockCommand { + return MockCommand{Stdout: "piped input data", ExitCode: 0} + }) + defer ClearMocks() + + got, err := RunStdin("some input", "cat") + if err != nil { + t.Fatal(err) + } + if got != "piped input data" { + t.Errorf("got %q, want %q", got, "piped input data") + } +} + +func TestRunStatus(t *testing.T) { + SetMock("check-update", func(args []string) MockCommand { + return MockCommand{Stdout: "updates available", ExitCode: 100} + }) + defer ClearMocks() + + out, code, err := RunStatus("check-update") + if err != nil { + t.Fatal(err) + } + if out != "updates available" || code != 100 { + t.Errorf("got out=%q code=%d, want out=%q code=100", out, code, "updates available") + } +} + +func TestRunStream(t *testing.T) { + SetMock("stream-cmd", func(args []string) MockCommand { + return MockCommand{Lines: []string{"line1", "line2"}, DelayMs: 1} + }) + defer ClearMocks() + + ch, err := RunStream(t.Context(), "stream-cmd") + if err != nil { + t.Fatal(err) + } + + var got []string + for line := range ch { + got = append(got, line) + } + + want := []string{"line1", "line2"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestRunStreamCombined(t *testing.T) { + SetMock("combined-cmd", func(args []string) MockCommand { + return MockCommand{Lines: []string{"output1", "output2"}, ExitCode: 0} + }) + defer ClearMocks() + + lines, errc, err := RunStreamCombined(t.Context(), nil, "combined-cmd") + if err != nil { + t.Fatal(err) + } + + var got []string + for line := range lines { + got = append(got, line) + } + + if err := <-errc; err != nil { + t.Errorf("stream failed with error: %v", err) + } + + want := []string{"output1", "output2"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} diff --git a/internal/rbac/benchmark_test.go b/internal/rbac/benchmark_test.go new file mode 100644 index 0000000..ce216e1 --- /dev/null +++ b/internal/rbac/benchmark_test.go @@ -0,0 +1,32 @@ +package rbac + +import ( + "testing" +) + +func BenchmarkCan(b *testing.B) { + r := New() + r.DefineRole(Role{Name: "admin", ModuleGrants: map[string][]Permission{"*": {"*"}}}) + r.DefineRole(Role{Name: "viewer", ModuleGrants: map[string][]Permission{ + "system": {"read"}, + "services": {"read"}, + "audit": {"read"}, + }}) + r.DefineRole(Role{Name: "operator", ModuleGrants: map[string][]Permission{ + "system": {"read", "root"}, + "services": {"read", "write"}, + }}) + + r.AssignRole("alice", "admin") + r.AssignRole("bob", "viewer") + r.AssignRole("charlie", "viewer") + r.AssignRole("charlie", "operator") + + // b.Loop automatically excludes the setup above from timing (no ResetTimer + // needed) and keeps the calls from being optimised away. + for b.Loop() { + _ = r.Can("charlie", "services", "write") + _ = r.Can("bob", "system", "root") + _ = r.Can("alice", "audit", "write") + } +} diff --git a/internal/rbac/middleware.go b/internal/rbac/middleware.go new file mode 100644 index 0000000..219ca4f --- /dev/null +++ b/internal/rbac/middleware.go @@ -0,0 +1,121 @@ +package rbac + +import ( + "net/http" + "net/url" + + "nadir/internal/auditlog" + "nadir/internal/auth" + + "github.com/danielgtaylor/huma/v2" +) + +func RbacMiddleware(api huma.API, sessions *auth.SessionStore, tokens *auth.TokenAuth, roles *RBAC, auditor *auditlog.Store) func(huma.Context, func(huma.Context)) { + return func(ctx huma.Context, next func(huma.Context)) { + // CSRF defense-in-depth (beyond the SameSite=Strict cookie): reject any + // state-changing request whose Origin doesn't match our Host. Runs before + // the permission check so it also covers the public login/logout writes. + // + // Bearer tokens are exempt by construction: browsers don't auto-attach an + // Authorization header, so a token request carries no ambient authority for + // CSRF to exploit. Such requests also send no Origin, so they pass here too. + if !sameOriginOK(ctx) { + huma.WriteErr(api, ctx, http.StatusForbidden, "cross-origin request blocked") + return + } + + op := ctx.Operation() + perm, _ := op.Metadata["permission"].(string) + if perm == "" { + next(ctx) + return + } + moduleID, _ := op.Metadata["module"].(string) + + // Machine-to-machine: a Bearer token authenticates as its token name (a + // dashboard managing N nodes needs no PAM session). Checked before the + // cookie. The token name is the RBAC subject - assign it a role in + // config.yaml's `assignments`, same as a username. + if raw, isBearer := auth.BearerToken(ctx.Header("Authorization")); isBearer { + if tokens == nil { + huma.WriteErr(api, ctx, http.StatusUnauthorized, "unauthorized") + return + } + name, ok, throttled := tokens.Verify(auth.ClientIP(ctx.Context()), raw) + if throttled { + huma.WriteErr(api, ctx, http.StatusTooManyRequests, "too many failed token attempts; wait a minute") + return + } + if !ok { + huma.WriteErr(api, ctx, http.StatusUnauthorized, "unauthorized") + return + } + // Audit actor is prefixed "token:" so the trail distinguishes a token + // from a human with the same name; RBAC still checks the bare name. + if !roles.Can(name, moduleID, Permission(perm)) { + huma.WriteErr(api, ctx, http.StatusForbidden, "forbidden") + record(auditor, "token:"+name, op, ctx.URL().Path, http.StatusForbidden) + return + } + next(ctx) + record(auditor, "token:"+name, op, ctx.URL().Path, ctx.Status()) + return + } + + cookie, err := huma.ReadCookie(ctx, "nadir_session_id") + if err != nil || cookie == nil { + huma.WriteErr(api, ctx, http.StatusUnauthorized, "unauthorized") + return + } + sess, ok := sessions.GetByToken(cookie.Value) + if !ok { + huma.WriteErr(api, ctx, http.StatusUnauthorized, "unauthorized") + return + } + if !roles.Can(sess.Username, moduleID, Permission(perm)) { + huma.WriteErr(api, ctx, http.StatusForbidden, "forbidden") + record(auditor, sess.Username, op, ctx.URL().Path, http.StatusForbidden) + return + } + next(ctx) + record(auditor, sess.Username, op, ctx.URL().Path, ctx.Status()) + } +} + +// sameOriginOK allows safe methods and any request whose Origin header matches +// the request Host. A missing Origin (non-browser client, or a same-origin +// navigation) is allowed - CSRF is a browser-only concern. Only browser-issued +// cross-origin writes, which always carry an Origin, are rejected. +func sameOriginOK(ctx huma.Context) bool { + switch ctx.Method() { + case http.MethodGet, http.MethodHead, http.MethodOptions: + return true + } + origin := ctx.Header("Origin") + if origin == "" { + return true + } + u, err := url.Parse(origin) + if err != nil { + return false + } + return u.Host == ctx.Host() +} + +// record logs a mutation (anything but a read) to the audit trail. Reads are +// skipped to keep the trail to "who changed what". Best-effort: a logging +// failure is reported to the server log, never to the caller. +func record(auditor *auditlog.Store, username string, op *huma.Operation, path string, status int) { + if op.Method == http.MethodGet || op.Method == http.MethodHead || op.Method == http.MethodOptions { + return + } + // SSE handlers stream via BodyWriter and never call SetStatus, so huma's + // context reports status 0 even though net/http has implicitly sent a 200. + // A streamed response that reached here passed RBAC and started, so record + // it as 200 (stream-level failures surface as `error` events, not codes). + if status == 0 { + status = http.StatusOK + } + moduleID, _ := op.Metadata["module"].(string) + auditor.Record(username, op.Method, path, moduleID, status) +} diff --git a/internal/rbac/middleware_test.go b/internal/rbac/middleware_test.go new file mode 100644 index 0000000..0d91645 --- /dev/null +++ b/internal/rbac/middleware_test.go @@ -0,0 +1,165 @@ +package rbac + +import ( + "context" + "net/http" + "os" + "path/filepath" + "testing" + + "nadir/internal/auditlog" + "nadir/internal/auth" + "nadir/internal/oscmd" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/danielgtaylor/huma/v2/humatest" +) + +func TestMain(m *testing.M) { + if oscmd.RunHelperProcess() { + return + } + os.Exit(m.Run()) +} + +func TestRbacMiddleware(t *testing.T) { + tempDir := t.TempDir() + auditStore, err := auditlog.New(filepath.Join(tempDir, "audit.db")) + if err != nil { + t.Fatal(err) + } + defer auditStore.Close() + + sessions, err := auth.NewSessionStore(filepath.Join(tempDir, "sessions.db")) + if err != nil { + t.Fatal(err) + } + + tokenStore, err := auth.NewTokenStore(filepath.Join(tempDir, "tokens.db")) + if err != nil { + t.Fatal(err) + } + tokenAuth := auth.NewTokenAuth(tokenStore) + + r := New() + r.DefineRole(Role{ + Name: "test-role", + ModuleGrants: map[string][]Permission{ + "test-mod": {Read, Write}, + }, + }) + r.AssignRole("alice", "test-role") + // A machine credential is just another RBAC subject: the token name is + // assigned a role exactly like a username. + r.AssignRole("dash", "test-role") + + mux := http.NewServeMux() + api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))) + + api.UseMiddleware(RbacMiddleware(api, sessions, tokenAuth, r, auditStore)) + + huma.Register(api, huma.Operation{ + OperationID: "public-get", + Method: "GET", + Path: "/public", + }, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) { + return &struct{ Body string }{Body: "public"}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "gated-get", + Method: "GET", + Path: "/gated-read", + Metadata: map[string]any{"module": "test-mod", "permission": "read"}, + }, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) { + return &struct{ Body string }{Body: "gated-read"}, nil + }) + + huma.Register(api, huma.Operation{ + OperationID: "gated-post", + Method: "POST", + Path: "/gated-write", + Metadata: map[string]any{"module": "test-mod", "permission": "write"}, + }, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) { + return &struct{ Body string }{Body: "gated-write"}, nil + }) + + // 1. Test public route + resp := api.Get("/public") + if resp.Code != http.StatusOK { + t.Errorf("public GET: got status %d, want %d", resp.Code, http.StatusOK) + } + + // 2. Test gated route without cookie -> 401 Unauthorized + resp = api.Get("/gated-read") + if resp.Code != http.StatusUnauthorized { + t.Errorf("gated GET no cookie: got status %d, want %d", resp.Code, http.StatusUnauthorized) + } + + // 3. Test gated route with invalid cookie -> 401 Unauthorized + resp = api.Get("/gated-read", "Cookie: nadir_session_id=invalid") + if resp.Code != http.StatusUnauthorized { + t.Errorf("gated GET invalid cookie: got status %d, want %d", resp.Code, http.StatusUnauthorized) + } + + // Create valid session + token, err := sessions.Create("alice") + if err != nil { + t.Fatal(err) + } + + // 4. Test gated route with valid cookie -> 200 OK + resp = api.Get("/gated-read", "Cookie: nadir_session_id="+token) + if resp.Code != http.StatusOK { + t.Errorf("gated GET valid cookie: got status %d, want %d", resp.Code, http.StatusOK) + } + + // 5. Test CSRF violation: POST with mismatched Origin header -> 403 Forbidden + resp = api.Post("/gated-write", "Cookie: nadir_session_id="+token, "Origin: http://evil.com", "Host: example.com", struct{}{}) + if resp.Code != http.StatusForbidden { + t.Errorf("CSRF mismatched Origin: got status %d, want %d", resp.Code, http.StatusForbidden) + } + + // 6. Test CSRF success: POST with matching Origin header -> 200 OK + resp = api.Post("/gated-write", "Cookie: nadir_session_id="+token, "Origin: http://example.com", "Host: example.com", struct{}{}) + if resp.Code != http.StatusOK { + t.Errorf("CSRF matching Origin: got status %d, want %d", resp.Code, http.StatusOK) + } + + // 7. Test gated route with unauthorized user -> 403 Forbidden + tokenBob, err := sessions.Create("bob") + if err != nil { + t.Fatal(err) + } + resp = api.Get("/gated-read", "Cookie: nadir_session_id="+tokenBob) + if resp.Code != http.StatusForbidden { + t.Errorf("bob unauthorized GET: got status %d, want %d", resp.Code, http.StatusForbidden) + } + + // 8. Bearer token for an assigned name -> 200 OK + rawToken, err := tokenStore.Create("dash") + if err != nil { + t.Fatal(err) + } + resp = api.Get("/gated-read", "Authorization: Bearer "+rawToken) + if resp.Code != http.StatusOK { + t.Errorf("valid bearer GET: got status %d, want %d", resp.Code, http.StatusOK) + } + + // 9. Bogus bearer token -> 401 Unauthorized + resp = api.Get("/gated-read", "Authorization: Bearer nad_deadbeef") + if resp.Code != http.StatusUnauthorized { + t.Errorf("bogus bearer GET: got status %d, want %d", resp.Code, http.StatusUnauthorized) + } + + // 10. Bearer token with no role assignment -> 403 Forbidden + rawUnassigned, err := tokenStore.Create("orphan") + if err != nil { + t.Fatal(err) + } + resp = api.Get("/gated-read", "Authorization: Bearer "+rawUnassigned) + if resp.Code != http.StatusForbidden { + t.Errorf("unassigned bearer GET: got status %d, want %d", resp.Code, http.StatusForbidden) + } +} diff --git a/internal/rbac/rbac.go b/internal/rbac/rbac.go new file mode 100644 index 0000000..5117cf6 --- /dev/null +++ b/internal/rbac/rbac.go @@ -0,0 +1,71 @@ +package rbac + +type Permission string + +const ( + Read Permission = "read" + Write Permission = "write" + // Root is the high-impact tier: destructive or irreversible operations + // (reboot, shutdown, account deletion, firewall flush, …) that callers + // should be able to grant separately from routine writes. + Root Permission = "root" + All Permission = "*" // wildcard: matches any permission +) + +// Wildcard is the module-key value that matches all modules in a Role's grants. +const Wildcard = "*" + +type Role struct { + Name string + ModuleGrants map[string][]Permission // module ID (or "*") -> permissions (each may be "*") +} + +type RBAC struct { + roles map[string]Role + userRoles map[string][]string +} + +func New() *RBAC { + return &RBAC{ + roles: make(map[string]Role), + userRoles: make(map[string][]string), + } +} + +func (r *RBAC) DefineRole(role Role) { + r.roles[role.Name] = role +} + +// RoleExists reports whether a role with the given name has been defined. +func (r *RBAC) RoleExists(name string) bool { + _, ok := r.roles[name] + return ok +} + +func (r *RBAC) AssignRole(username, roleName string) { + r.userRoles[username] = append(r.userRoles[username], roleName) +} + +// Can checks whether the user holds any role granting (module, perm), +// honoring "*" wildcards on both the module key and inside the permission list. +func (r *RBAC) Can(username, module string, perm Permission) bool { + for _, roleName := range r.userRoles[username] { + role, ok := r.roles[roleName] + if !ok { + continue + } + // Check the exact module key AND the wildcard module key. + for _, key := range []string{module, Wildcard} { + grants, ok := role.ModuleGrants[key] + if !ok { + continue + } + for _, p := range grants { + if p == perm || p == All { + return true + } + } + } + } + return false +} diff --git a/internal/rbac/rbac_test.go b/internal/rbac/rbac_test.go new file mode 100644 index 0000000..12a7a57 --- /dev/null +++ b/internal/rbac/rbac_test.go @@ -0,0 +1,84 @@ +package rbac + +import "testing" + +// build sets up an RBAC store with the given roles, then assigns them all to +// user "u". +func build(t *testing.T, roles ...Role) *RBAC { + t.Helper() + r := New() + for _, role := range roles { + r.DefineRole(role) + r.AssignRole("u", role.Name) + } + return r +} + +func TestCan(t *testing.T) { + admin := Role{Name: "admin", ModuleGrants: map[string][]Permission{Wildcard: {All}}} + reader := Role{Name: "reader", ModuleGrants: map[string][]Permission{Wildcard: {Read}}} + sysWrite := Role{Name: "sysw", ModuleGrants: map[string][]Permission{"system": {Read, Write}}} + + tests := []struct { + name string + role Role + module string + perm Permission + want bool + }{ + {"wildcard module + wildcard perm", admin, "system", Root, true}, + {"wildcard module + wildcard perm, other module", admin, "services", Write, true}, + {"wildcard module, read only, read", reader, "anything", Read, true}, + {"wildcard module, read only, write denied", reader, "anything", Write, false}, + {"exact module exact perm", sysWrite, "system", Write, true}, + {"exact module missing perm", sysWrite, "system", Root, false}, + {"exact module, different module denied", sysWrite, "services", Read, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := build(t, tt.role) + if got := r.Can("u", tt.module, tt.perm); got != tt.want { + t.Errorf("Can(u, %q, %q) = %v, want %v", tt.module, tt.perm, got, tt.want) + } + }) + } +} + +func TestCanNoRolesDenied(t *testing.T) { + r := New() + if r.Can("nobody", "system", Read) { + t.Fatal("user with no roles was granted access") + } +} + +func TestCanAssignedUndefinedRoleSkipped(t *testing.T) { + r := New() + r.AssignRole("u", "ghost") // never defined + if r.Can("u", "system", Read) { + t.Fatal("undefined role granted access") + } +} + +func TestCanUnionOfRoles(t *testing.T) { + r := build(t, + Role{Name: "a", ModuleGrants: map[string][]Permission{"system": {Read}}}, + Role{Name: "b", ModuleGrants: map[string][]Permission{"services": {Write}}}, + ) + if !r.Can("u", "system", Read) || !r.Can("u", "services", Write) { + t.Fatal("union of roles not honored") + } + if r.Can("u", "system", Write) { + t.Fatal("permission leaked across modules") + } +} + +func TestRoleExists(t *testing.T) { + r := New() + r.DefineRole(Role{Name: "x"}) + if !r.RoleExists("x") { + t.Error("defined role reported missing") + } + if r.RoleExists("y") { + t.Error("undefined role reported present") + } +} diff --git a/tls/cert.pem b/tls/cert.pem new file mode 100644 index 0000000..451bdcb --- /dev/null +++ b/tls/cert.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDbDCCAlSgAwIBAgIUfPrui00oPGw5okXJNXo0RP4dEMEwDQYJKoZIhvcNAQEL +BQAwLjEYMBYGA1UECgwPbmFkaXItZGV2LWxvY2FsMRIwEAYDVQQDDAlsb2NhbGhv +c3QwHhcNMjYwNjIwMDgwNDQ1WhcNMjcwNjIwMDgwNDQ1WjAuMRgwFgYDVQQKDA9u +YWRpci1kZXYtbG9jYWwxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcN +AQEBBQADggEPADCCAQoCggEBALifDPyGxdt32aP/QKbq+DISfwNzurcNu9hCqQT9 +keMgarPGnLd7J1aI6c12ypVz0DtFwfj9ALBTS7ga1x2FTo6cZtCF9tA0v6cpCEQs +mC0ecv8VOnWD9hy9teh6a/Elc48FMg1aug2/bPdHiKCfUT7yd7A7YEOl/o1RZj3A +1MkJllJ9CrHpgjmZBuypI21V0pYl1Wxwh3NvtMOAEkOp2OWZFsltzHoAtVbzPTR+ +5iDmrPUrmzToth0cakqBZlqlvU1e9bmSdiLdUj3Z0poZouGyjuuEixZzX3UJJwsM +DI8juRbWJioKjl7P+M5D6rZyB0RoO1n4N33lTluMqBeH8M8CAwEAAaOBgTB/MB0G +A1UdDgQWBBT9ugmFk9gRxY8Gkys7U4sYiiHiBzAfBgNVHSMEGDAWgBT9ugmFk9gR +xY8Gkys7U4sYiiHiBzAPBgNVHRMBAf8EBTADAQH/MCwGA1UdEQQlMCOCCWxvY2Fs +aG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEA +nT3ca+8u8uhPlqj4M/uLwdxYsvZbYDRkAs058aXSjFl2bvF+jYqJhbO9YR5SkqPk +fFXG2/d7gtQZAMpgfKI3fIaEkWGUvj98MDqTbSfksspbTpIQUBssV8eBRUi+L74f +Midm6Ua6+M0kqCCwyCVs6befTBSnmxUa/xsMfHwaDH2YoAVxySnuT5/lU6y+X5fP +eAsGokINX2rZNBZbQofQwJ5y1rfPye9u6giPMcyEjnyHUvHuNUvBcxCfNy8gvs7T ++/kUbO1uQFRm1HcXh5dfdkb1QjottH1JGoF0AA3l1XqP6m9EdtFldt7WyIHqr7NJ +D+qOxooyToC2JLCaxDAoZw== +-----END CERTIFICATE----- diff --git a/tls/key.pem b/tls/key.pem new file mode 100644 index 0000000..2649ebc --- /dev/null +++ b/tls/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC4nwz8hsXbd9mj +/0Cm6vgyEn8Dc7q3DbvYQqkE/ZHjIGqzxpy3eydWiOnNdsqVc9A7RcH4/QCwU0u4 +GtcdhU6OnGbQhfbQNL+nKQhELJgtHnL/FTp1g/YcvbXoemvxJXOPBTINWroNv2z3 +R4ign1E+8newO2BDpf6NUWY9wNTJCZZSfQqx6YI5mQbsqSNtVdKWJdVscIdzb7TD +gBJDqdjlmRbJbcx6ALVW8z00fuYg5qz1K5s06LYdHGpKgWZapb1NXvW5knYi3VI9 +2dKaGaLhso7rhIsWc191CScLDAyPI7kW1iYqCo5ez/jOQ+q2cgdEaDtZ+Dd95U5b +jKgXh/DPAgMBAAECggEAAWckAB8+Dabhfn+IDDyo2iiN0obkmlN+Y+xNwH30x9cN +OIR/2F0VNXEg5bDLZUtV/71N9ghmIvDfGG0LyWuj5y2FEnySHY7pDeof5/S2y1D5 +6rpMkWwJSLqgUT3s6A4yzJlrgfJ4i3Yy68YdYasUQPgytKIe3yS5xHUj48A9XbGz +ppbkHg76WuU9AUhCIgfSq0Xzz6cYv09eFx8ueiBlB2fGbjTs0j5CMlxLFkciNrYZ +2VNdLt5UuNvQFjj5gH0nWe76WAb2PjeQBzgiFdRKGh19AuClWG4URXLvnScGvSti +DuDY0QVusTCLeUJXx8bjxlK7wb7SKDGFGxjp7kr9gQKBgQDfWQG4Mts0fegxbsPy +WM42WQvSfxn6SSVDnoxAG4TZ0+FpaUKQKXgQ0z3sC8P78obE2aC+cPKWK8ZWHNt7 +6yaflG8yn4hodfXqb2JVp4WltXwPvG9s2+GnwFqs6yooXiRtnMp99ZUpqvnCg1i7 +e1mGe3nDBVMIquhP0VoVbFBgoQKBgQDTnKxO7IA2JtLQFWgQCDcO0uL2yvGXPLze +5pyjOCVuNUnA60ubVBdS+bB14KYZjdPLnX08QqtIdKjessHRAi36NStAA8cWUoeq +O10dWjmM5oMGtOnSn564+fj7nwoLKLWq9PQq18zuJll4MU9ntycgjAL9vTc5nKoT +0BlepJMrbwKBgQDbx5BDm/fM3aDhE+hJ0E2LeXCCwIPloJjEw32rj+jZGQCVY/kW +N1ho5hXm82T1xiAMEUN2Y1qzn3vaPSdV933YRo5tuELY2EsXWGfhdamz+LSOH5Vd +/7k8A7K2ueqQMqOSIVm5PTJ9ADwpxmpIgwcDqPmWiOS+gL9927rTnfQyQQKBgHUG +WNgQvFq2H7GJlRIAqQoen/uhgfeEVGLkn8032KNY/t+cgCR3Xaq6gNa/lLvfDji1 +cLOpnvWj5lu5+atvjCOp0bBGJox2uaXvzG/WHKuKMv27gO/E7E8ZlpL4geJn8geI +DZu/2gn91U693k7aH95E78aJJIhM1lW8qLsJQoYrAoGACZ5kJQcyasDCTh+cw/dB +B3qHhu9R416hr3OZrqDNHfqpI8D26ON87LECPg+veOCuZ3cmj/j6cCuQl1RXabxA +B0/kWfPc5M+WU9wgvRuYfBRfsNBVUiz3anEvo74rzFoUdl6rTg2OE2TtBz4ElITr +Ju5DB/Mki7nr8L0FHsJzMmY= +-----END PRIVATE KEY-----