first commit
This commit is contained in:
+15
@@ -0,0 +1,15 @@
|
||||
# Build artifacts
|
||||
/main
|
||||
/server
|
||||
/nadir
|
||||
|
||||
# Local environment / secrets
|
||||
.env
|
||||
config.yaml
|
||||
config.yml
|
||||
|
||||
# Editor
|
||||
*.swp
|
||||
server
|
||||
|
||||
CLAUDE.md
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 urania
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,497 @@
|
||||
# Nadir
|
||||
|
||||
Nadir is a lightweight, modular Linux system-administration backend - a modern,
|
||||
FOSS system admin panels. It exposes a typed REST API for the everyday tasks
|
||||
you'd otherwise SSH in to do: inspect the host, manage systemd services, edit
|
||||
local users and groups, install packages, and read logs - all behind
|
||||
role-based access control and a tamper-evident audit trail.
|
||||
|
||||
The API is generated with [Huma](https://huma.rocks) (OpenAPI 3.1) and ships
|
||||
interactive docs at `/docs`. The backend is Go and self-contained: no external
|
||||
database, no agent, no runtime dependencies beyond the standard system tools it
|
||||
drives (`systemctl`, `hostnamectl`, `useradd`, the host package manager, …).
|
||||
|
||||
---
|
||||
|
||||
## What it does
|
||||
|
||||
Functionality is organized into **modules**. Each module owns a slice of the
|
||||
API and declares its own permission vocabulary.
|
||||
|
||||
- **System** - Dashboard overview (OS/kernel, CPU, memory, disks, load, uptime,
|
||||
network interfaces, temperatures); get/set hostname; time, timezone, and NTP;
|
||||
locale and console keymap; reboot and power off.
|
||||
- **Services** - List and inspect systemd units; start / stop / restart / enable
|
||||
/ disable; read service logs from the journal or an allowlisted file, as a
|
||||
snapshot or a live Server-Sent-Events stream.
|
||||
- **Users** - List, inspect, create, and delete local accounts; set a password;
|
||||
set supplementary groups.
|
||||
- **Groups** - List, inspect, create, and delete local groups.
|
||||
- **Packages** - List installed packages and available updates; install, remove,
|
||||
and upgrade - streamed live over SSE. Auto-detects `dnf`, `apt`, or `pacman`.
|
||||
- **Networking** - List network interfaces, routing tables, and DNS settings; configure IPv4 settings with temporary applying and safety auto-rollback; bring interfaces up or down.
|
||||
- **Audit** - Read-only trail of every privileged write (who, what, when, result).
|
||||
- **Terminal** - Interactive shell access. Upgrades connection to a WebSocket and spawns a PTY shell as the logged-in user (requires `root` permission).
|
||||
- **Meta** - Self-description for clients: `/api/_modules`, `/api/whoami`,
|
||||
`/api/health`.
|
||||
|
||||
### Security model at a glance
|
||||
|
||||
- **Authentication** is delegated to PAM (`pam_unix`), so logins use real system
|
||||
credentials. A successful login sets an `HttpOnly`, `SameSite=Strict` session
|
||||
cookie; sessions are stored in SQLite and survive restarts.
|
||||
- **Machine credentials** for non-interactive callers (e.g. a central dashboard
|
||||
managing many nodes) authenticate with a static `Authorization: Bearer nad_…`
|
||||
token instead of a PAM session. Mint with `nadir token add <name>` (shown once,
|
||||
only its SHA-256 is stored); revoke with `nadir token rm <name>` (immediate, no
|
||||
restart). A token is an ordinary RBAC subject - its name is assigned a role in
|
||||
`config.yaml` `assignments`, so a leaked token is scoped, not implicitly admin.
|
||||
The audit trail records the actor as `token:<name>` to distinguish it from a
|
||||
human. CSRF does not apply: browsers never auto-attach a Bearer header, so the
|
||||
same-origin cookie defense is irrelevant for token auth. Bad-token guesses are
|
||||
throttled per source IP.
|
||||
- **Authorization** is RBAC driven entirely by `config.yaml`. Every protected
|
||||
operation declares a `module` and one of three permission tiers:
|
||||
- `read` - inspect (list users, read status, view logs…)
|
||||
- `write` - routine changes (create a user, restart a service, set the hostname…)
|
||||
- `root` - high-impact or irreversible actions (reboot, delete an account,
|
||||
**reset a password**, **change group membership**). Password and group-
|
||||
membership changes are `root` precisely because they can hand someone root.
|
||||
- **Brute-force throttling** on login (per username + source IP cooldown).
|
||||
- **CSRF** defense via `SameSite=Strict` plus a same-origin check on writes.
|
||||
- **Audit** of every mutation, written off the request path to SQLite.
|
||||
- The server **must run as root** - PAM reads `/etc/shadow`, and the system
|
||||
tools it drives (`hostnamectl`, `systemctl`, `useradd`, `shutdown`, …) require
|
||||
it.
|
||||
|
||||
<!-- api-desc-end -->
|
||||
|
||||
---
|
||||
|
||||
## Installing
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Linux with **systemd** (the Services module and the `nadir` service wrapper
|
||||
use it).
|
||||
- **Root** access (see above).
|
||||
- Go (recent) to build from source.
|
||||
|
||||
### Build
|
||||
|
||||
The entry point is the `main` package under `cmd/server`:
|
||||
|
||||
```bash
|
||||
go build -o nadir ./cmd/server
|
||||
```
|
||||
|
||||
This produces a single static-ish binary, `nadir`.
|
||||
|
||||
### Run directly
|
||||
|
||||
On first start, `nadir` requires a configuration file to exist. If the configuration is missing, the server will fail to start and ask you to run `nadir install` (to install the systemd service) or use `--save-config`.
|
||||
|
||||
To generate a default configuration file (assigning the admin role to your current user) without installing the systemd service:
|
||||
|
||||
```bash
|
||||
./nadir --save-config
|
||||
```
|
||||
|
||||
To save it for the root user (who runs the server):
|
||||
|
||||
```bash
|
||||
sudo ./nadir --save-config
|
||||
```
|
||||
|
||||
You can also specify a custom path using `-f`/`--config`:
|
||||
|
||||
```bash
|
||||
./nadir --save-config -f ./config.yaml
|
||||
```
|
||||
|
||||
Once the configuration file is created, start the server directly:
|
||||
|
||||
```bash
|
||||
sudo ./nadir # same as: sudo ./nadir run
|
||||
```
|
||||
|
||||
By default it reads `~/.config/config.yaml` (resolving to the running user's home, i.e., `/root/.config/config.yaml` when run as root); override with the `-f`/`--config` flag or `CONFIG_PATH` env var:
|
||||
|
||||
```bash
|
||||
sudo ./nadir -f /etc/nadir/config.yaml
|
||||
# or: sudo CONFIG_PATH=/etc/nadir/config.yaml ./nadir
|
||||
```
|
||||
|
||||
By default it serves **HTTPS** with a self-signed certificate (see
|
||||
[Deployment note 2](#2-tls-three-modes)) on the `hostname:port` from the config,
|
||||
and exposes interactive docs at `https://<host>:<port>/docs` and the raw spec at
|
||||
`/openapi.json`.
|
||||
|
||||
### Run in the background (`-d`)
|
||||
|
||||
Like `docker run -d`, this detaches from the terminal and returns your shell:
|
||||
|
||||
```bash
|
||||
sudo ./nadir run -d
|
||||
# nadir running in background (pid 12345); logs: /var/lib/nadir/server.log
|
||||
# follow with: nadir logs
|
||||
```
|
||||
|
||||
Output goes to `/var/lib/nadir/server.log`.
|
||||
|
||||
### Install as a systemd service (start on boot)
|
||||
|
||||
For a real deployment, register nadir as a service so it starts on boot and is
|
||||
managed with the usual tooling:
|
||||
|
||||
```bash
|
||||
sudo ./nadir install # writes the unit, enables it, and starts it now
|
||||
sudo ./nadir status
|
||||
sudo ./nadir logs # follow the journal live
|
||||
```
|
||||
|
||||
`install` writes `/etc/systemd/system/nadir.service` pinning the **absolute**
|
||||
binary and the absolute config file path (so it doesn't depend on the working directory at
|
||||
boot), runs `systemctl daemon-reload`, and `enable --now`. If no configuration file
|
||||
exists at the target path, `install` automatically creates a default config file and
|
||||
assigns the admin role to the installing user.
|
||||
|
||||
### CLI reference
|
||||
|
||||
| Command | Effect |
|
||||
| ------------------------------------------------ | --------------------------------------------------------------------------- |
|
||||
| `nadir [run] [-d]` | Start the server. `-d` / `--detach` runs it in the background. |
|
||||
| `nadir --save-config` | Save the default configuration template to the target path and exit. |
|
||||
| `nadir install` | Install + enable the systemd service (starts now and on boot). |
|
||||
| `nadir uninstall` | Stop, disable, and remove the systemd service. |
|
||||
| `nadir start` \| `stop` \| `restart` \| `status` | Control the running service. |
|
||||
| `nadir enable` \| `disable` | Toggle start-on-boot without removing the unit. |
|
||||
| `nadir logs` | Follow logs - journald if installed as a service, otherwise the detach log. |
|
||||
| `nadir help` | Show usage. |
|
||||
|
||||
Most commands need root.
|
||||
|
||||
---
|
||||
|
||||
## Configuration (`config.yaml`)
|
||||
|
||||
`config.yaml` is the single source of truth for runtime configuration: server
|
||||
and TLS settings, which roles exist, what each role can do, and who holds which
|
||||
role. By default, it reads `~/.config/config.yaml`. The path can be overridden using
|
||||
the `-f` / `--config` CLI flags or the `CONFIG_PATH` environment variable.
|
||||
|
||||
```yaml
|
||||
server:
|
||||
secure_tls: true # Secure flag on the session cookie (keep true behind TLS)
|
||||
trust_proxy: true # a reverse proxy terminates TLS; see Deployment note 3
|
||||
# tls_cert: /etc/nadir/tls/cert.pem # or terminate TLS in nadir yourself
|
||||
# tls_key: /etc/nadir/tls/key.pem
|
||||
hostname: 100.64.0.189
|
||||
port: 9999
|
||||
|
||||
# Quote "*" - bare * is YAML alias syntax and fails to parse.
|
||||
roles:
|
||||
admin:
|
||||
"*": ["*"] # every permission on every module (including future ones)
|
||||
auditor:
|
||||
"*": ["read"] # read-only everywhere
|
||||
system_ops:
|
||||
system: ["read", "write"]
|
||||
|
||||
assignments:
|
||||
urania: [admin]
|
||||
|
||||
# Optional: per-unit allowlist of log files the Services module may read.
|
||||
log_files:
|
||||
nginx:
|
||||
- /var/log/nginx/access.log
|
||||
- /var/log/nginx/error.log
|
||||
```
|
||||
|
||||
### `server`
|
||||
|
||||
| Key | Default | Meaning |
|
||||
| --------------------- | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `secure_tls` | `true` | Sets the `Secure` flag on the session cookie. Keep `true` whenever the browser reaches nadir over HTTPS (direct or via proxy); `false` only for local plain-HTTP dev. |
|
||||
| `trust_proxy` | `false` | When `true`, nadir serves plaintext HTTP and trusts `X-Forwarded-For` / forwarded `Host` from the proxy. See [Deployment note 3](#3-reverse-proxy--vpn). |
|
||||
| `tls_cert`, `tls_key` | - | PEM paths. When both are set (and `trust_proxy` is off), nadir terminates TLS with this pair. |
|
||||
| `hostname` | - | Address to bind. Use `127.0.0.1` for local-only, or an overlay/VPN address to expose nadir only on that interface. |
|
||||
| `port` | - | TCP port to listen on. |
|
||||
|
||||
TLS selection is covered in [Deployment note 2](#2-tls-three-modes).
|
||||
|
||||
### `roles` / `assignments`
|
||||
|
||||
- `roles` maps a role name to `module → [permissions]`. `"*"` as the module key
|
||||
means "all modules"; `"*"` in the permission list means "all permissions".
|
||||
- `assignments` maps a username to the roles they hold; effective grants are the
|
||||
union.
|
||||
- `"*"` must be quoted - bare `*` is YAML alias syntax and fails to parse.
|
||||
- Module keys and permissions are validated at startup against the modules
|
||||
actually compiled in. An unknown module, an unexported permission, or an
|
||||
assignment to an undefined role aborts startup with a clear message rather
|
||||
than silently granting or denying access.
|
||||
- Each module owns its permission vocabulary via `Permissions()`, so adding a
|
||||
module automatically makes it available to wildcard roles and validatable for
|
||||
restricted ones. Clients discover the live module/permission set at
|
||||
`GET /api/_modules`, and a user's own grants at `GET /api/whoami`.
|
||||
|
||||
### `log_files`
|
||||
|
||||
An allowlist, keyed by unit, of log file paths the Services module is allowed to
|
||||
read via the `source=file` log endpoints. The caller can only read paths an
|
||||
admin has listed here - never an arbitrary file.
|
||||
|
||||
---
|
||||
|
||||
## Deployment notes
|
||||
|
||||
These notes capture the non-obvious operational decisions. They'll seed
|
||||
the formal installation guide.
|
||||
|
||||
### 1. PAM service
|
||||
|
||||
**Nadir authenticates against its own PAM service, `/etc/pam.d/nadir`, and the
|
||||
server creates that file on startup if it is missing** (see
|
||||
`internal/auth/pamservice.go`). Here is why.
|
||||
|
||||
#### What went wrong with stock services
|
||||
|
||||
Originally we authenticated against the `"login"` service. On a Framework
|
||||
laptop (and many other machines) `/etc/pam.d/login` pulls in `system-auth`,
|
||||
whose auth stack lists `pam_fprintd.so` as `sufficient` **before**
|
||||
`pam_unix.so`:
|
||||
|
||||
```
|
||||
auth sufficient pam_fprintd.so # fingerprint, tried first
|
||||
auth sufficient pam_unix.so nullok # password, only reached if fprintd fails
|
||||
```
|
||||
|
||||
Our PAM conversation callback only answers the password prompt; it can't swipe
|
||||
a finger. So `pam_fprintd` would start a fingerprint scan and **block until its
|
||||
~30-second timeout** before falling through to the password check. Every login
|
||||
took 30s. (It was never a network, D-Bus, systemd, or NSS problem —
|
||||
`hostnamectl` was instant and there is no SSSD/LDAP on the box.)
|
||||
|
||||
Switching to `"passwd"` is not a fix either: `/etc/pam.d/passwd` has only a
|
||||
`password` stack and no `auth` stack, so it can't verify a login.
|
||||
|
||||
#### The fix
|
||||
|
||||
Ship a dedicated, minimal service - exactly what `sshd`, `cockpit`, and
|
||||
`polkit` do. `/etc/pam.d/nadir` contains only:
|
||||
|
||||
```
|
||||
#%PAM-1.0
|
||||
auth required pam_unix.so
|
||||
account required pam_unix.so
|
||||
```
|
||||
|
||||
That is a straight `/etc/shadow` password check plus an account-validity check
|
||||
— no fingerprint, no systemd, no env loading, no DNS. Authentication drops from
|
||||
~30s to milliseconds, and we stop inheriting whatever the distro's login stack
|
||||
happens to do.
|
||||
|
||||
Notes:
|
||||
|
||||
- We omit `nullok` on purpose: this service is reachable over the network, and
|
||||
`nullok` would let passwordless accounts log in.
|
||||
- `EnsurePAMService()` **only writes the file when it is absent** - a missing
|
||||
service falls through to `/etc/pam.d/other` (`pam_deny`), which looks identical
|
||||
to "wrong credentials". If an admin customizes the file, nadir leaves it
|
||||
untouched.
|
||||
- `pam_unix` reads `/etc/shadow`, so the server must run as root.
|
||||
|
||||
### 2. TLS: three modes
|
||||
|
||||
Credentials and session cookies must never travel in cleartext. Nadir picks how
|
||||
the connection is secured from `config.yaml`, in priority order:
|
||||
|
||||
1. **Behind a reverse proxy** (`trust_proxy: true`) - a proxy such as Traefik
|
||||
terminates TLS and forwards plaintext to nadir on a trusted network. Keep
|
||||
`secure_tls: true` (the browser↔proxy leg is HTTPS). This is the deployment
|
||||
covered in note 3.
|
||||
2. **Nadir terminates TLS** (`tls_cert` + `tls_key`) - point both at a PEM
|
||||
certificate/key pair and nadir serves HTTPS directly. Use this when there is
|
||||
no proxy.
|
||||
3. **Self-signed (dev only)** - when none of the above is configured, nadir
|
||||
generates a fresh in-memory self-signed certificate (valid for `localhost`
|
||||
and the loopback addresses, one year). Browsers will warn; that's expected.
|
||||
Never rely on this in production.
|
||||
|
||||
To create a persistent self-signed pair for mode 2 in development:
|
||||
|
||||
```bash
|
||||
openssl req -x509 -newkey rsa:2048 -nodes \
|
||||
-keyout key.pem -out cert.pem -days 365 \
|
||||
-subj "/O=nadir-dev-local/CN=localhost" \
|
||||
-addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1"
|
||||
```
|
||||
|
||||
…then set `tls_cert`/`tls_key` to those paths.
|
||||
|
||||
### 3. Reverse proxy + VPN
|
||||
|
||||
When nadir runs behind a TLS-terminating reverse proxy (e.g. Traefik) on a
|
||||
private overlay network, set `trust_proxy: true`. Nadir then serves plaintext
|
||||
HTTP and trusts `X-Forwarded-For` (used by the login throttle) and the forwarded
|
||||
`Host` (used by the CSRF same-origin check). **That trust is only safe if
|
||||
nothing but the proxy can reach the app's port** - otherwise any client that
|
||||
reaches it directly can forge those headers.
|
||||
|
||||
The recommended shape: the proxy and the app each sit on a WireGuard-based
|
||||
overlay, and nadir binds to its overlay address so the public/LAN interfaces
|
||||
never answer.
|
||||
|
||||
```yaml
|
||||
server:
|
||||
trust_proxy: true
|
||||
secure_tls: true # browser↔proxy leg is HTTPS, so keep the cookie Secure
|
||||
hostname: 100.64.0.189 # the app's overlay IP - only the VPN interface listens
|
||||
port: 9999
|
||||
```
|
||||
|
||||
**Netbird / Tailscale** assign peers out of `100.64.0.0/10` (RFC 6598 CGNAT),
|
||||
which is not publicly routable - binding there means only VPN peers can connect.
|
||||
**Plain WireGuard** is the same idea with a private range you pick (e.g.
|
||||
`10.0.0.0/24`); bind to the app's address on the `wg0` interface.
|
||||
|
||||
Two things make the header trust airtight:
|
||||
|
||||
1. **Restrict the port to the proxy peer only.** Binding to the overlay limits
|
||||
reachability to _all_ VPN peers, not just the proxy. Tighten it so only the
|
||||
proxy can reach `:9999`:
|
||||
- _Netbird_: an access-control policy allowing the proxy peer/group → the app
|
||||
peer on tcp/9999, denying others.
|
||||
- _Tailscale_: an ACL rule (`"src": ["tag:proxy"], "dst": ["tag:nadir:9999"]`).
|
||||
- _Plain WireGuard_: a host firewall rule on the app, e.g.
|
||||
`iptables -A INPUT -i wg0 ! -s <proxy-wg-ip> -p tcp --dport 9999 -j DROP`.
|
||||
|
||||
2. **Make the proxy overwrite client-supplied forwarded headers.** Otherwise a
|
||||
client sending its own `X-Forwarded-For` / `X-Forwarded-Host` can have it
|
||||
passed through. In Traefik, mark the overlay as trusted on the entrypoint:
|
||||
|
||||
```yaml
|
||||
# traefik static config
|
||||
entryPoints:
|
||||
websecure:
|
||||
address: ":443"
|
||||
forwardedHeaders:
|
||||
trustedIPs:
|
||||
- 100.64.0.0/10 # or your wg subnet, e.g. 10.0.0.0/24
|
||||
```
|
||||
|
||||
And ensure it forwards the original host (Traefik does by default; nginx needs
|
||||
`proxy_set_header Host $host;`), since the CSRF check compares `Origin`
|
||||
against `Host`.
|
||||
|
||||
With both in place, the only path to the app is proxy → overlay → app, and the
|
||||
forwarded headers are trustworthy. Without step 1 you're trusting every peer on
|
||||
the overlay - fine for a single-tenant network you fully control, risky on a
|
||||
shared one.
|
||||
|
||||
### 4. Connecting a dashboard (machine clients)
|
||||
|
||||
To manage one or more Nadir instances via a central dashboard or non-interactive client, authenticate requests using a static Bearer token rather than interactive PAM credentials.
|
||||
|
||||
Here is how to authorize and connect a dashboard:
|
||||
|
||||
#### Step 1: Mint a token
|
||||
Run `nadir token add <name>` (for example, `dashboard`) to generate a unique API key:
|
||||
```bash
|
||||
sudo nadir token add dashboard
|
||||
```
|
||||
This generates a secure token starting with `nad_`. **Copy this token immediately**; only its SHA-256 hash is stored in `/var/lib/nadir/tokens.db` (shared via SQLite WAL between server and CLI), and the raw key cannot be retrieved again.
|
||||
|
||||
#### Step 2: Authorize the token in `config.yaml`
|
||||
Minting and authorizing are deliberately separate steps (safe default). A newly minted token does not grant any access.
|
||||
|
||||
To grant the token a role, edit the `assignments` map in your `config.yaml`:
|
||||
```yaml
|
||||
assignments:
|
||||
dashboard: [admin] # or another role like [system_ops] or [auditor]
|
||||
```
|
||||
The audit log will record mutations performed by this token as `token:<name>` (e.g., `token:dashboard`), distinguishing it from human logins.
|
||||
|
||||
#### Step 3: Restart Nadir
|
||||
While token creation and revocation (`nadir token rm`) are written to the database and take effect immediately, policy assignments live in `config.yaml`. To reload the configuration and authorize the new token name, you must restart the Nadir server:
|
||||
```bash
|
||||
sudo systemctl restart nadir
|
||||
```
|
||||
|
||||
#### Step 4: Configure the dashboard client
|
||||
Configure your client to include the token in the HTTP `Authorization` header of every API request:
|
||||
```http
|
||||
Authorization: Bearer nad_your_secret_token_here
|
||||
```
|
||||
|
||||
#### Note on CORS / Cross-Origin requests
|
||||
If your dashboard runs as a web application directly in the user's browser (cross-origin relative to the Nadir instance) and makes state-changing write requests (`POST`, `PUT`, `DELETE`), the browser will include an `Origin` header.
|
||||
|
||||
To defend against CSRF, Nadir's middleware rejects state-changing requests if an `Origin` header is present and does not match the request's `Host` header.
|
||||
|
||||
To connect a browser-based dashboard hosted on a different origin, choose one of these patterns:
|
||||
1. **Server-to-Server Calls (Recommended):** Build the dashboard with a backend that calls Nadir's API. Because the backend is not a browser, it does not send an `Origin` header, allowing the requests to pass.
|
||||
2. **Reverse Proxy:** Terminate the dashboard and the Nadir instance under the same origin (e.g., dashboard at `https://control.example.com/` and Nadir at `https://control.example.com/api/nadir-node-1/`), letting a reverse proxy route the requests.
|
||||
3. **Header Rewriting:** Have a proxy in front of Nadir rewrite/strip the `Origin` header for authorized token requests before forwarding them to Nadir.
|
||||
|
||||
---
|
||||
|
||||
## Layout
|
||||
|
||||
```
|
||||
cmd/ process entry point + CLI (run / install / logs …), TLS, service wiring
|
||||
internal/auth PAM auth, sessions, login/logout, login throttle, PAM service install
|
||||
internal/config config.yaml loader + startup validation
|
||||
internal/meta /api/_modules, /api/whoami, /api/health discovery endpoints
|
||||
internal/module the Module interface
|
||||
internal/modules concrete modules:
|
||||
system - info, hostname, time/timezone/NTP, locale/keymap, power
|
||||
services - systemd unit control + journal/file logs (snapshot + SSE)
|
||||
users - local accounts
|
||||
groups - local groups
|
||||
packages - dnf/apt/pacman install/remove/upgrade (streamed)
|
||||
audit - read-only audit trail
|
||||
networking - network interfaces, routing tables, DNS, and IP configurations
|
||||
terminal - interactive PTY shell over WebSocket
|
||||
internal/oscmd shared command runner (timeouts, stderr surfacing) + helpers
|
||||
internal/rbac roles, permissions ("*" wildcards), HTTP middleware (RBAC + CSRF)
|
||||
internal/audit SQLite-backed audit log writer
|
||||
```
|
||||
|
||||
## API docs
|
||||
|
||||
With the server running, browse `https://<host>:<port>/docs` for the Scalar UI,
|
||||
or fetch the raw OpenAPI document from `/openapi.json`.
|
||||
|
||||
---
|
||||
|
||||
## Built with LLM assistance
|
||||
|
||||
This project was built with the help of large language models - but every
|
||||
architectural choice, security decision, and operational trade-off is the
|
||||
author's. The LLM never drove; it was a power tool, not a co-pilot with the
|
||||
wheel.
|
||||
|
||||
In practice, the workflow looks like this: the author designs the feature,
|
||||
decides how it should fit into the existing module structure, specifies the API
|
||||
surface, and defines the security and permission semantics. The LLM then
|
||||
accelerates the mechanical side - scaffolding boilerplate, drafting
|
||||
implementations from precise instructions, generating documentation, and
|
||||
proposing test cases. Every line of output is reviewed, corrected where needed,
|
||||
and integrated only when it meets the project's standards.
|
||||
|
||||
What the LLM provides is _commodity leverage_: it collapses the time between
|
||||
"I know exactly what I want" and "it's written, tested, and documented." What
|
||||
it does not provide is judgment - that stays with the person who understands the
|
||||
system, its threat model, and its users.
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
[MIT](./LICENSE)
|
||||
|
||||
## Credits
|
||||
|
||||
Favicon: [Orbit](https://lucide.dev/icons/orbit) from [Lucide](https://lucide.dev), recolored. Lucide icons are licensed under the [ISC License](https://github.com/lucide-icons/lucide/blob/main/LICENSE).
|
||||
@@ -0,0 +1,28 @@
|
||||
// Package nadir exists only to embed shared documentation (the README) so a
|
||||
// single source of truth can feed both GitHub and the OpenAPI description.
|
||||
package nadir
|
||||
|
||||
import _ "embed"
|
||||
|
||||
// README is the project README. Content up to the "<!-- api-desc-end -->" marker
|
||||
// is reused as the API description; see cmd/main.go.
|
||||
//
|
||||
//go:embed README.md
|
||||
var README string
|
||||
|
||||
// Favicon is the orbit icon (lucide.dev, recolored midnight-blue) served as
|
||||
// both the app favicon and the /docs page icon.
|
||||
//
|
||||
//go:embed favicon.svg
|
||||
var Favicon string
|
||||
|
||||
// InstallScriptTemplate is the curl|sh bootstrap script served at /install.sh.
|
||||
// It contains the placeholder __NADIR_BASE_URL__, substituted at request time
|
||||
// with the scheme+host the script was fetched from - so it downloads the
|
||||
// binary from the very instance that served it. See cmd/server/server.go.
|
||||
//
|
||||
//go:embed install.sh.tmpl
|
||||
var InstallScriptTemplate string
|
||||
|
||||
// Version is the current version of the Nadir application.
|
||||
const Version = "0.0.1"
|
||||
@@ -0,0 +1,96 @@
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Nadir configuration - config.yaml
|
||||
#
|
||||
# This is the single source of truth for runtime settings: server/TLS,
|
||||
# roles, role assignments, and log-file allowlists. The only env var
|
||||
# nadir reads is CONFIG_PATH, the bootstrap pointer to this file.
|
||||
# By default, nadir uses ~/.config/config.yaml. You can override this
|
||||
# with the -f / --config flag or by setting the CONFIG_PATH env var.
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
server:
|
||||
# Secure attribute on the session cookie.
|
||||
# Keep true when the browser reaches nadir over HTTPS (direct or via proxy).
|
||||
# Set false only for local plain-HTTP development.
|
||||
secure_tls: true
|
||||
|
||||
# TLS mode, in priority order:
|
||||
#
|
||||
# 1. trust_proxy: true
|
||||
# A reverse proxy (e.g. Traefik) terminates TLS and forwards plaintext
|
||||
# to nadir. Bind hostname to a private/overlay address so only the proxy
|
||||
# can reach nadir. X-Forwarded-For is trusted in this mode.
|
||||
#
|
||||
# 2. tls_cert + tls_key (uncomment below)
|
||||
# Nadir terminates TLS itself with your PEM certificate and key.
|
||||
#
|
||||
# 3. Neither (default)
|
||||
# Nadir generates a fresh in-memory self-signed certificate on every
|
||||
# start. Browsers will warn - this is for development only.
|
||||
#
|
||||
# Keep secure_tls: true in modes 1 and 2.
|
||||
# trust_proxy: false
|
||||
# tls_cert: /etc/nadir/tls/cert.pem
|
||||
# tls_key: /etc/nadir/tls/key.pem
|
||||
|
||||
# Address and port to bind.
|
||||
# Use 127.0.0.1 for local-only, or an overlay/VPN address (e.g. a Netbird
|
||||
# or Tailscale IP) to expose nadir only on that interface.
|
||||
hostname: localhost
|
||||
port: 9999
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Roles
|
||||
#
|
||||
# Maps a role name to { module → [permissions] }.
|
||||
# - Module key "*" means "all modules (including future ones)".
|
||||
# - Permission "*" means "all permissions the module exports".
|
||||
# - IMPORTANT: quote "*" - bare * is YAML alias syntax and fails to parse.
|
||||
#
|
||||
# Each module exports its own permission vocabulary via Permissions().
|
||||
# Valid tiers are: read, write, root. Unknown modules or permissions
|
||||
# cause a startup error, not a silent denial.
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
roles:
|
||||
# Full access - every permission on every module.
|
||||
admin:
|
||||
"*": ["*"]
|
||||
|
||||
# Read-only on all modules - good for monitoring dashboards.
|
||||
# auditor:
|
||||
# "*": ["read"]
|
||||
|
||||
# Scoped operator - can read and write the system module only.
|
||||
# system_ops:
|
||||
# system: ["read", "write"]
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Assignments
|
||||
#
|
||||
# Maps a local username to one or more roles. Effective grants are the
|
||||
# union of all assigned roles' permissions. The username must match a
|
||||
# real system account (PAM authenticates against /etc/shadow).
|
||||
#
|
||||
# A machine credential (Bearer token) is assigned a role the same way, by its
|
||||
# token name rather than a system username. Mint with `nadir token add
|
||||
# central-dashboard`, then grant it scoped access below. Until it is listed
|
||||
# here the token authenticates but can do nothing.
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
assignments:
|
||||
# Replace with your admin username.
|
||||
ubuntu: [admin]
|
||||
# central-dashboard: [auditor]
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Log files (optional)
|
||||
#
|
||||
# Per-unit allowlist of log files the Services module may serve via the
|
||||
# source=file log endpoints. Only paths listed here are readable - the
|
||||
# caller can never request an arbitrary file.
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# log_files:
|
||||
# nginx:
|
||||
# - /var/log/nginx/access.log
|
||||
# - /var/log/nginx/error.log
|
||||
# files:
|
||||
# - /var/log/auth.log
|
||||
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#191970" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-orbit-icon lucide-orbit"><path d="M20.341 6.484A10 10 0 0 1 10.266 21.85"/><path d="M3.659 17.516A10 10 0 0 1 13.74 2.152"/><circle cx="12" cy="12" r="3"/><circle cx="19" cy="5" r="2"/><circle cx="5" cy="19" r="2"/></svg>
|
||||
|
After Width: | Height: | Size: 418 B |
@@ -0,0 +1,26 @@
|
||||
module nadir
|
||||
|
||||
go 1.26.4
|
||||
|
||||
require github.com/msteinert/pam v1.2.0
|
||||
|
||||
require github.com/danielgtaylor/huma/v2 v2.38.0
|
||||
|
||||
require (
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.52.0
|
||||
github.com/coder/websocket v1.8.15
|
||||
github.com/creack/pty v1.1.24
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
modernc.org/libc v1.72.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
)
|
||||
@@ -0,0 +1,74 @@
|
||||
github.com/coder/websocket v1.8.15 h1:6B2JPeOGlpff2Uz6vOEH1Vzpi0iUz20A+lPVhPHtNUA=
|
||||
github.com/coder/websocket v1.8.15/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/danielgtaylor/huma/v2 v2.38.0 h1:fb0WZCatnaiHLphMQDDWDjygNxfMkX/ENma3QsRl7vY=
|
||||
github.com/danielgtaylor/huma/v2 v2.38.0/go.mod h1:k9hwjlgWFt1t2jsmQGlsgXAG2FBTZa4kkjV581qAtfo=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/msteinert/pam v1.2.0 h1:mYfjlvN2KYs2Pb9G6nb/1f/nPfAttT/Jee5Sq9r3bGE=
|
||||
github.com/msteinert/pam v1.2.0/go.mod h1:d2n0DCUK8rGecChV3JzvmsDjOY4R7AYbsNxAT+ftQl0=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
|
||||
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
|
||||
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
|
||||
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
|
||||
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
|
||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
|
||||
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
||||
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo=
|
||||
modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
@@ -0,0 +1,52 @@
|
||||
#!/bin/sh
|
||||
set -e
|
||||
# Nadir bootstrap installer.
|
||||
#
|
||||
# Downloads the nadir binary from the SAME instance that served this script
|
||||
# and installs it as a systemd service. Intended for spinning up nadir on a
|
||||
# sibling host quickly (same architecture as the source instance):
|
||||
#
|
||||
# curl -fsSL https://<existing-nadir-host>:<port>/install.sh | sudo sh
|
||||
#
|
||||
# This is not a general-purpose distribution channel: there is no separate
|
||||
# release server, no signature, and no version pinning beyond "whatever the
|
||||
# source instance is currently running." Trust it exactly as much as you
|
||||
# trust the source host.
|
||||
#
|
||||
# Everything is wrapped in do_install and invoked on the last line, the same
|
||||
# pattern get.docker.com uses: when piped via `curl | sh`, sh executes while
|
||||
# still receiving bytes from the network. Defining the whole body as a
|
||||
# function before running anything means sh has fully parsed the script
|
||||
# before do_install ever runs, which avoids exactly the failure mode where a
|
||||
# command further down (here, the binary download) races the still-open
|
||||
# script-fetch connection and fails with a write error.
|
||||
|
||||
BASE_URL="__NADIR_BASE_URL__"
|
||||
|
||||
do_install() {
|
||||
if [ "$(id -u)" -ne 0 ]; then
|
||||
echo "this script must be run as root (try: curl -fsSL $BASE_URL/install.sh | sudo sh)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v curl >/dev/null 2>&1; then
|
||||
echo "curl is required" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "downloading nadir from $BASE_URL ..."
|
||||
# No -s here: the progress meter is the whole point of leaving it visible,
|
||||
# so you can see how much is left mid-download.
|
||||
curl -f --progress-bar -L "$BASE_URL/nadir-binary" -o /usr/local/bin/nadir.tmp
|
||||
mv /usr/local/bin/nadir.tmp /usr/local/bin/nadir
|
||||
chmod +x /usr/local/bin/nadir
|
||||
|
||||
echo "binary installed at /usr/local/bin/nadir"
|
||||
echo "installing as a systemd service ..."
|
||||
/usr/local/bin/nadir install
|
||||
|
||||
echo
|
||||
echo "done. check status with: nadir status"
|
||||
}
|
||||
|
||||
do_install
|
||||
@@ -0,0 +1,148 @@
|
||||
// Package audit records privileged write operations to an embedded SQLite
|
||||
// database so there is a durable "who did what" trail.
|
||||
package auditlog
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// buffer sizes the channel that decouples Record (request path) from the DB
|
||||
// writer goroutine. Bursts up to this many entries return immediately; only a
|
||||
// sustained overflow falls back to a synchronous write.
|
||||
const buffer = 1024
|
||||
|
||||
// Entry is one recorded action.
|
||||
type Entry struct {
|
||||
Time string `json:"time" example:"2026-06-20T08:15:04Z" doc:"When the action occurred (RFC3339, UTC)"`
|
||||
Username string `json:"username" example:"alice" doc:"Who performed it"`
|
||||
Method string `json:"method" example:"POST" doc:"HTTP method"`
|
||||
Path string `json:"path" example:"/api/users" doc:"Request path"`
|
||||
Module string `json:"module" example:"users" doc:"Target module"`
|
||||
Status int `json:"status" example:"200" doc:"HTTP response status"`
|
||||
|
||||
// ts is the capture time (unix seconds), carried through the writer channel
|
||||
// so the stored timestamp reflects when the action happened, not when the
|
||||
// background writer got to it.
|
||||
ts int64
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
ch chan Entry
|
||||
done chan struct{} // closed when the writer has drained and exited
|
||||
|
||||
// mu guards the channel against a Record racing Close. Record takes RLock
|
||||
// (many concurrent senders); Close takes the write lock to flip closed and
|
||||
// close the channel exactly once. Without it, a Record after Close would
|
||||
// panic sending on a closed channel.
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// New opens (creating if needed) the audit database at path and starts the
|
||||
// background writer that drains recorded entries to disk.
|
||||
func New(path string) (*Store, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return nil, fmt.Errorf("audit db dir: %w", err)
|
||||
}
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open audit db: %w", err)
|
||||
}
|
||||
// ponytail: single connection serializes writes, same rationale as the
|
||||
// session store; the background writer is the only writer.
|
||||
db.SetMaxOpenConns(1)
|
||||
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS audit (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts INTEGER NOT NULL,
|
||||
username TEXT NOT NULL,
|
||||
method TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
module TEXT NOT NULL,
|
||||
status INTEGER NOT NULL
|
||||
)`); err != nil {
|
||||
return nil, fmt.Errorf("create audit table: %w", err)
|
||||
}
|
||||
s := &Store{db: db, ch: make(chan Entry, buffer), done: make(chan struct{})}
|
||||
go s.writer()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// writer is the single goroutine that persists entries, keeping all DB writes
|
||||
// off the request path. It exits once Close closes the channel, after draining
|
||||
// whatever is buffered.
|
||||
func (s *Store) writer() {
|
||||
defer close(s.done)
|
||||
for e := range s.ch {
|
||||
s.insert(e)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops accepting entries, drains those still buffered, and closes the
|
||||
// database. Call it during shutdown, after the HTTP server has stopped so no
|
||||
// further Record calls can race with the channel close.
|
||||
func (s *Store) Close() error {
|
||||
s.mu.Lock()
|
||||
s.closed = true
|
||||
close(s.ch)
|
||||
s.mu.Unlock()
|
||||
<-s.done
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *Store) insert(e Entry) {
|
||||
if _, err := s.db.Exec(
|
||||
`INSERT INTO audit (ts, username, method, path, module, status) VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
e.ts, e.Username, e.Method, e.Path, e.Module, e.Status,
|
||||
); err != nil {
|
||||
log.Printf("audit: insert failed (%s %s by %s): %v", e.Method, e.Path, e.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Record queues an entry for the background writer and returns immediately. If
|
||||
// the buffer is full (sustained overload), it writes synchronously instead so an
|
||||
// audit entry is never silently dropped - only then does it pay DB latency.
|
||||
func (s *Store) Record(username, method, path, module string, status int) {
|
||||
e := Entry{ts: time.Now().Unix(), Username: username, Method: method, Path: path, Module: module, Status: status}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.closed {
|
||||
s.insert(e) // post-shutdown straggler: write directly, channel is gone
|
||||
return
|
||||
}
|
||||
select {
|
||||
case s.ch <- e:
|
||||
default:
|
||||
s.insert(e)
|
||||
}
|
||||
}
|
||||
|
||||
// List returns the most recent entries, newest first, capped at limit.
|
||||
func (s *Store) List(limit int) ([]Entry, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT ts, username, method, path, module, status FROM audit ORDER BY id DESC LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
entries := []Entry{}
|
||||
for rows.Next() {
|
||||
var e Entry
|
||||
var ts int64
|
||||
if err := rows.Scan(&ts, &e.Username, &e.Method, &e.Path, &e.Module, &e.Status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.Time = time.Unix(ts, 0).UTC().Format(time.RFC3339)
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package auditlog
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
s, err := New(filepath.Join(t.TempDir(), "audit.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// eventually polls List until at least want entries are persisted (writes are
|
||||
// async) or it times out.
|
||||
func eventually(t *testing.T, s *Store, want int) []Entry {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for {
|
||||
got, err := s.List(100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) >= want {
|
||||
return got
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("timed out waiting for %d entries, have %d", want, len(got))
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAndList(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
s.Record("alice", "POST", "/api/users", "users", 200)
|
||||
s.Record("bob", "DELETE", "/api/users/x", "users", 403)
|
||||
|
||||
got := eventually(t, s, 2)
|
||||
// Newest first.
|
||||
if got[0].Username != "bob" || got[0].Status != 403 {
|
||||
t.Errorf("newest entry wrong: %+v", got[0])
|
||||
}
|
||||
if got[1].Username != "alice" || got[1].Method != "POST" {
|
||||
t.Errorf("oldest entry wrong: %+v", got[1])
|
||||
}
|
||||
if got[0].Time == "" {
|
||||
t.Error("time not populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListLimit(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
for range 5 {
|
||||
s.Record("u", "POST", "/api/x", "system", 200)
|
||||
}
|
||||
eventually(t, s, 5)
|
||||
got, err := s.List(3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 3 {
|
||||
t.Errorf("limit not honored: got %d, want 3", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseDrains(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "audit.db")
|
||||
s, err := New(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for range 50 {
|
||||
s.Record("u", "POST", "/api/x", "system", 200)
|
||||
}
|
||||
// Close must flush everything still buffered before returning.
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reopen and confirm all 50 made it to disk.
|
||||
s2, err := New(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, err := s2.List(100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 50 {
|
||||
t.Errorf("Close did not drain: got %d entries, want 50", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListEmpty(t *testing.T) {
|
||||
got, err := newTestStore(t).List(10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Errorf("fresh store should be empty, got %d", len(got))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenAuth verifies Bearer credentials against the TokenStore and throttles
|
||||
// brute force. Unlike the login throttle (keyed on username+IP), a Bearer token
|
||||
// has no "login" step to rate-limit, so guesses are throttled by source IP
|
||||
// alone. The window is looser than login's because a legitimate dashboard may
|
||||
// fire many requests in a minute - only repeated *failures* count.
|
||||
type TokenAuth struct {
|
||||
store *TokenStore
|
||||
throttle *failLimiter
|
||||
}
|
||||
|
||||
// NewTokenAuth wraps a store with an IP-keyed failure throttle.
|
||||
func NewTokenAuth(store *TokenStore) *TokenAuth {
|
||||
return &TokenAuth{store: store, throttle: newFailLimiter(20, time.Minute)}
|
||||
}
|
||||
|
||||
// Verify resolves a presented Bearer token to its name. throttled is true when
|
||||
// the source IP is in cooldown after too many bad tokens; the caller should
|
||||
// answer 429 without consulting the store.
|
||||
func (a *TokenAuth) Verify(ip, raw string) (name string, ok, throttled bool) {
|
||||
if a.throttle.blocked(ip) {
|
||||
return "", false, true
|
||||
}
|
||||
name, found := a.store.Lookup(raw)
|
||||
if !found {
|
||||
a.throttle.fail(ip)
|
||||
return "", false, false
|
||||
}
|
||||
a.throttle.reset(ip)
|
||||
return name, true, false
|
||||
}
|
||||
|
||||
// BearerToken extracts the credential from an Authorization header value,
|
||||
// reporting whether it was a Bearer scheme. Returns ("", false) for cookie-only
|
||||
// or unauthenticated requests.
|
||||
func BearerToken(authHeader string) (string, bool) {
|
||||
const scheme = "Bearer "
|
||||
if len(authHeader) <= len(scheme) || !strings.EqualFold(authHeader[:len(scheme)], scheme) {
|
||||
return "", false
|
||||
}
|
||||
return strings.TrimSpace(authHeader[len(scheme):]), true
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"nadir/internal/auditlog"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// authenticator verifies a username/password (PAM in production). It's a field
|
||||
// of the login handler rather than a package global so tests can inject a stub
|
||||
// without mutating shared state.
|
||||
type authenticator func(username, password string) error
|
||||
|
||||
type LoginInput struct {
|
||||
Body struct {
|
||||
Username string `json:"username" doc:"System username"`
|
||||
Password string `json:"password" doc:"System password"`
|
||||
}
|
||||
}
|
||||
|
||||
type LoginOutput struct {
|
||||
SetCookie http.Cookie `header:"Set-Cookie"`
|
||||
Body struct {
|
||||
Status string `json:"status" example:"logged in"`
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterLogin wires the login operation into the Huma API. It has no
|
||||
// permission metadata, so the RBAC middleware lets it through unauthenticated.
|
||||
//
|
||||
// secure sets the Secure attribute on the session cookie. Keep it true in
|
||||
// production (the cookie is then only sent over HTTPS); set it false for local
|
||||
// development over plain HTTP, where a Secure cookie would never be sent back.
|
||||
func RegisterLogin(api huma.API, sessions *SessionStore, auditor *auditlog.Store, secure bool) {
|
||||
// loginThrottle blunts brute force: 5 failures for a username+source IP
|
||||
// trigger a one-minute cooldown. See throttle.go for the ceiling/upgrade path.
|
||||
registerLogin(api, sessions, auditor, secure, Authenticate, newFailLimiter(5, time.Minute))
|
||||
}
|
||||
|
||||
func registerLogin(api huma.API, sessions *SessionStore, auditor *auditlog.Store, secure bool, authenticate authenticator, throttle *failLimiter) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "login",
|
||||
Method: "POST",
|
||||
Path: "/api/login",
|
||||
Summary: "Authenticate and start a session",
|
||||
Description: "Verifies the username and password against PAM (the " +
|
||||
"dedicated nadir service) and, on success, sets an HttpOnly " +
|
||||
"session cookie used to authorize all other endpoints.",
|
||||
Tags: []string{"Authentication"},
|
||||
Errors: []int{401, 429},
|
||||
}, func(ctx context.Context, in *LoginInput) (*LoginOutput, error) {
|
||||
// Throttle brute force: too many recent failures for this account/source
|
||||
// put it in a short cooldown before the password is even checked.
|
||||
throttleKey := in.Body.Username + "|" + ClientIP(ctx)
|
||||
if throttle.blocked(throttleKey) {
|
||||
return nil, huma.Error429TooManyRequests("too many failed login attempts; wait a minute")
|
||||
}
|
||||
// Record both outcomes: failed logins are the brute-force signal, and the
|
||||
// username is captured even on failure (which account is being targeted).
|
||||
if err := authenticate(in.Body.Username, in.Body.Password); err != nil {
|
||||
throttle.fail(throttleKey)
|
||||
auditor.Record(in.Body.Username, "POST", "/api/login", "auth", http.StatusUnauthorized)
|
||||
return nil, huma.Error401Unauthorized("invalid credentials")
|
||||
}
|
||||
throttle.reset(throttleKey)
|
||||
auditor.Record(in.Body.Username, "POST", "/api/login", "auth", http.StatusOK)
|
||||
|
||||
sessionID, err := sessions.Create(in.Body.Username)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("could not create session", err)
|
||||
}
|
||||
out := &LoginOutput{
|
||||
SetCookie: http.Cookie{
|
||||
Name: "nadir_session_id",
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Expires: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}
|
||||
out.Body.Status = "logged in"
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nadir/internal/auditlog"
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestEnsurePAMService(t *testing.T) {
|
||||
tempFile := filepath.Join(t.TempDir(), "nadir-pam-test")
|
||||
oldPath := pamServicePath
|
||||
pamServicePath = tempFile
|
||||
defer func() { pamServicePath = oldPath }()
|
||||
|
||||
if err := EnsurePAMService(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tempFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != pamServiceContent {
|
||||
t.Errorf("got content %q, want %q", string(data), pamServiceContent)
|
||||
}
|
||||
|
||||
customContent := "custom pam content"
|
||||
if err := os.WriteFile(tempFile, []byte(customContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := EnsurePAMService(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err = os.ReadFile(tempFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != customContent {
|
||||
t.Errorf("EnsurePAMService clobbered file: got %q, want %q", string(data), customContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginLogoutThrottling(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
auditStore, err := auditlog.New(filepath.Join(tempDir, "audit.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer auditStore.Close()
|
||||
|
||||
sessions, err := NewSessionStore(filepath.Join(tempDir, "sessions.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
|
||||
// Inject a stub authenticator and a short-window throttle (3 failures) at
|
||||
// registration — no package globals to mutate. The throttle keys on
|
||||
// username+IP, so the success/failure cases below and the throttle case (a
|
||||
// distinct user) don't interfere.
|
||||
authMock := func(username, password string) error {
|
||||
if password == "correct" {
|
||||
return nil
|
||||
}
|
||||
return errors.New("pam error")
|
||||
}
|
||||
registerLogin(api, sessions, auditStore, false, authMock, newFailLimiter(3, 500*time.Millisecond))
|
||||
RegisterLogout(api, sessions, false)
|
||||
|
||||
// 1. Test failed login
|
||||
resp := api.Post("/api/login", struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "admin",
|
||||
Password: "wrong",
|
||||
})
|
||||
if resp.Code != http.StatusUnauthorized {
|
||||
t.Errorf("failed login: got code %d, want %d", resp.Code, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// 2. Test successful login
|
||||
resp = api.Post("/api/login", struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "admin",
|
||||
Password: "correct",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("successful login: got code %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
cookieHeader := resp.Header().Get("Set-Cookie")
|
||||
if !strings.Contains(cookieHeader, "nadir_session_id=") {
|
||||
t.Fatalf("Set-Cookie header missing nadir_session_id: %q", cookieHeader)
|
||||
}
|
||||
|
||||
var sessionID string
|
||||
parts := strings.SplitSeq(cookieHeader, ";")
|
||||
for part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if after, ok := strings.CutPrefix(part, "nadir_session_id="); ok {
|
||||
sessionID = after
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if sessionID == "" {
|
||||
t.Fatal("nadir_session_id cookie not found")
|
||||
}
|
||||
|
||||
_, ok := sessions.GetByToken(sessionID)
|
||||
if !ok {
|
||||
t.Fatal("session not found in session store")
|
||||
}
|
||||
|
||||
// 3. Test logout
|
||||
resp = api.Post("/api/logout", "Cookie: nadir_session_id="+sessionID, struct{}{})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("logout failed: got code %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
_, ok = sessions.GetByToken(sessionID)
|
||||
if ok {
|
||||
t.Fatal("session still valid after logout")
|
||||
}
|
||||
|
||||
// 4. Test throttling (the handler was registered with a 3-failure limiter).
|
||||
for range 3 {
|
||||
api.Post("/api/login", struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "throttled-user",
|
||||
Password: "wrong",
|
||||
})
|
||||
}
|
||||
|
||||
resp = api.Post("/api/login", struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "throttled-user",
|
||||
Password: "correct",
|
||||
})
|
||||
if resp.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("throttled login: got code %d, want %d", resp.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
resp = api.Post("/api/login", struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: "throttled-user",
|
||||
Password: "correct",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("login after cooldown: got code %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
type LogoutInput struct {
|
||||
SessionID string `cookie:"nadir_session_id"`
|
||||
}
|
||||
|
||||
type LogoutOutput struct {
|
||||
SetCookie http.Cookie `header:"Set-Cookie"`
|
||||
Body struct {
|
||||
Status string `json:"status" example:"logged out"`
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterLogout wires the logout operation. Like login it carries no permission
|
||||
// metadata, so the RBAC middleware lets it through: it deletes whatever session
|
||||
// the cookie names (a no-op for an unknown/expired token) and clears the cookie.
|
||||
func RegisterLogout(api huma.API, sessions *SessionStore, secure bool) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "logout",
|
||||
Method: "POST",
|
||||
Path: "/api/logout",
|
||||
Summary: "End the current session",
|
||||
Description: "Invalidates the session named by the cookie and clears it. " +
|
||||
"Always succeeds, even without a valid session.",
|
||||
Tags: []string{"Authentication"},
|
||||
}, func(ctx context.Context, in *LogoutInput) (*LogoutOutput, error) {
|
||||
if in.SessionID != "" {
|
||||
if err := sessions.Delete(in.SessionID); err != nil {
|
||||
return nil, huma.Error500InternalServerError("could not end session", err)
|
||||
}
|
||||
}
|
||||
out := &LogoutOutput{
|
||||
SetCookie: http.Cookie{
|
||||
Name: "nadir_session_id",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
MaxAge: -1, // delete the cookie now
|
||||
},
|
||||
}
|
||||
out.Body.Status = "logged out"
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package auth
|
||||
|
||||
import "github.com/msteinert/pam"
|
||||
|
||||
func Authenticate(username, password string) error {
|
||||
t, err := pam.StartFunc(PAMService, username, func(s pam.Style, msg string) (string, error) {
|
||||
if s == pam.PromptEchoOff {
|
||||
return password, nil
|
||||
}
|
||||
return "", nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return t.Authenticate(0)
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// PAMService is the name of the PAM service nadir authenticates against.
|
||||
// It maps to /etc/pam.d/<PAMService>. We use a dedicated service rather than
|
||||
// a stock one like "login" so we control exactly which modules run during
|
||||
// authentication. See README.md ("PAM service") for the full rationale.
|
||||
const PAMService = "nadir"
|
||||
|
||||
var pamServicePath = "/etc/pam.d/" + PAMService
|
||||
|
||||
// pamServiceContent is the minimal stack nadir needs: verify the password
|
||||
// against /etc/shadow and confirm the account is valid. It deliberately omits
|
||||
// pam_fprintd (blocks ~30s waiting for a fingerprint swipe that never comes),
|
||||
// pam_systemd, pam_env, and the rest of the distro's login stack.
|
||||
const pamServiceContent = `#%PAM-1.0
|
||||
# Managed by nadir. Do not rely on hand edits surviving - nadir recreates this
|
||||
# file on startup only if it is missing. Minimal auth stack: verify the
|
||||
# password against /etc/shadow and confirm the account is valid. Deliberately
|
||||
# omits pam_fprintd (blocks ~30s on a fingerprint swipe), pam_systemd, pam_env.
|
||||
auth required pam_unix.so
|
||||
account required pam_unix.so
|
||||
`
|
||||
|
||||
// EnsurePAMService writes the PAM service file if it is missing. nadir already
|
||||
// runs as root (pam_unix needs to read /etc/shadow), so it can install its own
|
||||
// PAM config rather than relying on a separate install step.
|
||||
//
|
||||
// It will not overwrite an existing file: an admin who has customized
|
||||
// /etc/pam.d/nadir keeps their version. Returns an error only on a real I/O
|
||||
// problem so main can fail loudly instead of later looking like bad
|
||||
// credentials (a missing file falls through to pam_deny via /etc/pam.d/other).
|
||||
func EnsurePAMService() error {
|
||||
switch _, err := os.Stat(pamServicePath); {
|
||||
case err == nil:
|
||||
return nil // already present - leave admin customizations intact
|
||||
case os.IsNotExist(err):
|
||||
// fall through and create it
|
||||
default:
|
||||
return fmt.Errorf("stat %s: %w", pamServicePath, err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(pamServicePath, []byte(pamServiceContent), 0644); err != nil {
|
||||
return fmt.Errorf("write %s: %w (need root, and /etc/pam.d must be writable)", pamServicePath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const sessionTTL = 24 * time.Hour
|
||||
|
||||
type Session struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
// SessionStore persists sessions in an embedded SQLite database so they survive
|
||||
// process restarts. Expired rows are dropped lazily on read.
|
||||
type SessionStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSessionStore opens (creating if needed) the SQLite database at path and
|
||||
// ensures the sessions table exists. The parent directory is created too.
|
||||
func NewSessionStore(path string) (*SessionStore, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return nil, fmt.Errorf("session db dir: %w", err)
|
||||
}
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open session db: %w", err)
|
||||
}
|
||||
// ponytail: single connection serializes access, avoiding SQLite's
|
||||
// "database is locked" under concurrent writes. A sysadmin panel's session
|
||||
// rate is trivial; switch to WAL + a real pool only if that ever bites.
|
||||
db.SetMaxOpenConns(1)
|
||||
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS sessions (
|
||||
token TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
expires_at INTEGER NOT NULL
|
||||
)`); err != nil {
|
||||
return nil, fmt.Errorf("create sessions table: %w", err)
|
||||
}
|
||||
return &SessionStore{db: db}, nil
|
||||
}
|
||||
|
||||
// Ping reports whether the session database is reachable. Used by the health
|
||||
// check.
|
||||
func (s *SessionStore) Ping() error { return s.db.Ping() }
|
||||
|
||||
func (s *SessionStore) Create(username string) (string, error) {
|
||||
token := randomToken()
|
||||
expires := time.Now().Add(sessionTTL)
|
||||
if _, err := s.db.Exec(
|
||||
`INSERT INTO sessions (token, username, expires_at) VALUES (?, ?, ?)`,
|
||||
token, username, expires.Unix(),
|
||||
); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Delete removes a session, invalidating it immediately (logout). Deleting an
|
||||
// unknown token is a no-op.
|
||||
func (s *SessionStore) Delete(token string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM sessions WHERE token = ?`, token)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SessionStore) GetByToken(token string) (Session, bool) {
|
||||
var username string
|
||||
var expires int64
|
||||
err := s.db.QueryRow(
|
||||
`SELECT username, expires_at FROM sessions WHERE token = ?`, token,
|
||||
).Scan(&username, &expires)
|
||||
if err != nil {
|
||||
return Session{}, false
|
||||
}
|
||||
if time.Now().Unix() > expires {
|
||||
s.db.Exec(`DELETE FROM sessions WHERE token = ?`, token)
|
||||
return Session{}, false
|
||||
}
|
||||
return Session{Username: username}, true
|
||||
}
|
||||
|
||||
func randomToken() string {
|
||||
b := make([]byte, 32)
|
||||
// crypto/rand.Read never returns an error on supported platforms; if it
|
||||
// somehow does, an all-zero (guessable) token would be a security hole, so
|
||||
// fail hard rather than hand one out.
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic("crypto/rand failed: " + err.Error())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionPersistsAcrossReopen(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "sessions.db")
|
||||
|
||||
store, err := NewSessionStore(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token, err := store.Create("urania")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reopen the same file: a fresh process must still see the session.
|
||||
reopened, err := NewSessionStore(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sess, ok := reopened.GetByToken(token)
|
||||
if !ok || sess.Username != "urania" {
|
||||
t.Fatalf("session lost after reopen: got %+v ok=%v", sess, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiredSessionRejected(t *testing.T) {
|
||||
store, err := NewSessionStore(filepath.Join(t.TempDir(), "sessions.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Write a row that expired an hour ago, bypassing Create's TTL.
|
||||
_, err = store.db.Exec(
|
||||
`INSERT INTO sessions (token, username, expires_at) VALUES (?, ?, ?)`,
|
||||
"stale", "urania", time.Now().Add(-time.Hour).Unix(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := store.GetByToken("stale"); ok {
|
||||
t.Fatal("expired session was accepted")
|
||||
}
|
||||
// Lazy cleanup should have deleted the row.
|
||||
var n int
|
||||
store.db.QueryRow(`SELECT count(*) FROM sessions WHERE token = ?`, "stale").Scan(&n)
|
||||
if n != 0 {
|
||||
t.Fatalf("expired row not cleaned up: %d rows remain", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInvalidatesSession(t *testing.T) {
|
||||
store, err := NewSessionStore(filepath.Join(t.TempDir(), "sessions.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token, err := store.Create("urania")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := store.GetByToken(token); !ok {
|
||||
t.Fatal("session should exist before logout")
|
||||
}
|
||||
if err := store.Delete(token); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := store.GetByToken(token); ok {
|
||||
t.Fatal("session still valid after logout")
|
||||
}
|
||||
// Deleting an unknown/already-deleted token is a no-op, not an error.
|
||||
if err := store.Delete(token); err != nil {
|
||||
t.Errorf("deleting unknown token should be a no-op, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownTokenRejected(t *testing.T) {
|
||||
store, err := NewSessionStore(filepath.Join(t.TempDir(), "sessions.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := store.GetByToken("nope"); ok {
|
||||
t.Fatal("unknown token was accepted")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// failLimiter throttles repeated login failures keyed by "username|ip". After
|
||||
// max failures it imposes a cooldown of window before any further attempt for
|
||||
// that key is accepted, blunting brute force against PAM/shadow.
|
||||
//
|
||||
// ponytail: in-memory, single-process, fixed cooldown - correct for a single
|
||||
// app instance. State is lost on restart, but only an operator restarts the
|
||||
// process (an attacker can't), so persisting it would buy nothing. Source
|
||||
// spoofing is handled by the network (VPN-only + trusted proxy sets XFF), not
|
||||
// here. Reach for pam_faillock only if a failed web login should also lock the
|
||||
// OS account against ssh/console - a different layer we deliberately don't span.
|
||||
type failLimiter struct {
|
||||
mu sync.Mutex
|
||||
attempts map[string]*attemptState
|
||||
max int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
type attemptState struct {
|
||||
count int
|
||||
until time.Time
|
||||
}
|
||||
|
||||
// maxTrackedKeys bounds memory: an attacker rotating username/IP can't grow the
|
||||
// map without limit. When exceeded we drop all throttle state - a crude reset
|
||||
// that briefly forgets cooldowns, acceptable for a single-node panel.
|
||||
const maxTrackedKeys = 10000
|
||||
|
||||
func newFailLimiter(max int, window time.Duration) *failLimiter {
|
||||
return &failLimiter{attempts: map[string]*attemptState{}, max: max, window: window}
|
||||
}
|
||||
|
||||
// blocked reports whether the key is currently in cooldown.
|
||||
func (l *failLimiter) blocked(key string) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
s := l.attempts[key]
|
||||
return s != nil && time.Now().Before(s.until)
|
||||
}
|
||||
|
||||
// fail records a failed attempt and starts a cooldown once max is reached.
|
||||
func (l *failLimiter) fail(key string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if len(l.attempts) > maxTrackedKeys {
|
||||
l.attempts = map[string]*attemptState{}
|
||||
}
|
||||
s := l.attempts[key]
|
||||
if s == nil {
|
||||
s = &attemptState{}
|
||||
l.attempts[key] = s
|
||||
}
|
||||
s.count++
|
||||
if s.count >= l.max {
|
||||
s.until = time.Now().Add(l.window)
|
||||
s.count = 0 // restart the window after the cooldown is set
|
||||
}
|
||||
}
|
||||
|
||||
// reset clears state for a key after a successful login.
|
||||
func (l *failLimiter) reset(key string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
delete(l.attempts, key)
|
||||
}
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const clientIPKey ctxKey = 0
|
||||
|
||||
// WithClientIP wraps a handler so the client's IP is available downstream via
|
||||
// ClientIP, so the login throttle can key on source IP without coupling to a
|
||||
// specific HTTP adapter.
|
||||
//
|
||||
// When trustProxy is set, the IP is taken from the first X-Forwarded-For hop
|
||||
// (the original client behind the reverse proxy). X-Forwarded-For is fully
|
||||
// caller-controlled, so this is only honored when an admin has opted in by
|
||||
// putting nadir behind a proxy - and nadir must then be reachable only by that
|
||||
// proxy. When trustProxy is false the header is ignored and RemoteAddr wins.
|
||||
func WithClientIP(trustProxy bool, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
if trustProxy {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
first, _, _ := strings.Cut(xff, ",")
|
||||
ip = strings.TrimSpace(first)
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), clientIPKey, ip)))
|
||||
})
|
||||
}
|
||||
|
||||
// ClientIP returns the IP recorded by WithClientIP, or "" if absent.
|
||||
func ClientIP(ctx context.Context) string {
|
||||
ip, _ := ctx.Value(clientIPKey).(string)
|
||||
return ip
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// tokenPrefix marks a Nadir machine credential, like GitHub's "ghp_". It's
|
||||
// cosmetic (helps secret scanners and humans recognize a leaked token), not a
|
||||
// security boundary - the prefix is hashed and stored along with the rest.
|
||||
const tokenPrefix = "nad_"
|
||||
|
||||
// TokenStore persists machine credentials for non-interactive callers (a
|
||||
// central dashboard managing N nodes), so they authenticate with a static
|
||||
// Bearer token instead of a per-host PAM session. Only the SHA-256 of each
|
||||
// token is stored - a leaked DB or backup can't hand out live credentials.
|
||||
//
|
||||
// It lives in its own SQLite file because both the server (read, on every
|
||||
// Bearer request) and the `nadir token` CLI (write, when minting/revoking)
|
||||
// touch it. WAL + a busy timeout let those two processes share the file without
|
||||
// "database is locked".
|
||||
type TokenStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// TokenInfo is one stored credential's public metadata (never the secret).
|
||||
type TokenInfo struct {
|
||||
Name string
|
||||
Created time.Time
|
||||
}
|
||||
|
||||
// NewTokenStore opens (creating if needed) the token database at path.
|
||||
func NewTokenStore(path string) (*TokenStore, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return nil, fmt.Errorf("token db dir: %w", err)
|
||||
}
|
||||
// WAL + busy_timeout: the server process and the CLI process both open this
|
||||
// file, so a plain rollback journal would surface transient lock errors.
|
||||
dsn := path + "?_pragma=busy_timeout(5000)&_pragma=journal_mode(WAL)"
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open token db: %w", err)
|
||||
}
|
||||
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS tokens (
|
||||
name TEXT PRIMARY KEY,
|
||||
token_hash TEXT NOT NULL UNIQUE,
|
||||
created_at INTEGER NOT NULL
|
||||
)`); err != nil {
|
||||
return nil, fmt.Errorf("create tokens table: %w", err)
|
||||
}
|
||||
return &TokenStore{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *TokenStore) Close() error { return s.db.Close() }
|
||||
|
||||
// Create mints a new token named name and returns the raw secret. The secret is
|
||||
// shown only here, at generation time - only its hash is stored. A duplicate
|
||||
// name is rejected (the caller should revoke and re-mint to rotate).
|
||||
func (s *TokenStore) Create(name string) (string, error) {
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("token name required")
|
||||
}
|
||||
raw := tokenPrefix + randomToken()
|
||||
if _, err := s.db.Exec(
|
||||
`INSERT INTO tokens (name, token_hash, created_at) VALUES (?, ?, ?)`,
|
||||
name, hashToken(raw), time.Now().Unix(),
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("store token %q (already exists?): %w", name, err)
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// Lookup returns the token name for a presented raw secret.
|
||||
//
|
||||
// No constant-time compare is needed: we look the secret up by its SHA-256, so
|
||||
// the value compared in the index is already a hash of attacker-controlled
|
||||
// input. A timing side-channel could at most leak a stored hash, which is
|
||||
// useless without a SHA-256 preimage (the actual token).
|
||||
func (s *TokenStore) Lookup(raw string) (string, bool) {
|
||||
if !strings.HasPrefix(raw, tokenPrefix) {
|
||||
return "", false
|
||||
}
|
||||
var name string
|
||||
err := s.db.QueryRow(
|
||||
`SELECT name FROM tokens WHERE token_hash = ?`, hashToken(raw),
|
||||
).Scan(&name)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return name, true
|
||||
}
|
||||
|
||||
// Delete revokes a token by name. Revoking an unknown name is a no-op. The
|
||||
// change is effective immediately - the server reads this DB live, no restart.
|
||||
func (s *TokenStore) Delete(name string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM tokens WHERE name = ?`, name)
|
||||
return err
|
||||
}
|
||||
|
||||
// List returns all token names with their creation time, newest first.
|
||||
func (s *TokenStore) List() ([]TokenInfo, error) {
|
||||
rows, err := s.db.Query(`SELECT name, created_at FROM tokens ORDER BY created_at DESC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
infos := []TokenInfo{}
|
||||
for rows.Next() {
|
||||
var t TokenInfo
|
||||
var created int64
|
||||
if err := rows.Scan(&t.Name, &created); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Created = time.Unix(created, 0).UTC()
|
||||
infos = append(infos, t)
|
||||
}
|
||||
return infos, rows.Err()
|
||||
}
|
||||
|
||||
func hashToken(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTokenStore(t *testing.T) {
|
||||
store, err := NewTokenStore(filepath.Join(t.TempDir(), "tokens.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
raw, err := store.Create("dash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(raw) < len(tokenPrefix)+32 || raw[:len(tokenPrefix)] != tokenPrefix {
|
||||
t.Fatalf("token %q lacks %q prefix or is too short", raw, tokenPrefix)
|
||||
}
|
||||
|
||||
// Round-trip: the minted secret resolves to its name.
|
||||
if name, ok := store.Lookup(raw); !ok || name != "dash" {
|
||||
t.Errorf("Lookup(valid) = %q,%v; want dash,true", name, ok)
|
||||
}
|
||||
// A wrong secret (and a non-prefixed one) must not resolve.
|
||||
if _, ok := store.Lookup(tokenPrefix + "wrong"); ok {
|
||||
t.Error("Lookup(wrong) succeeded")
|
||||
}
|
||||
if _, ok := store.Lookup("no-prefix"); ok {
|
||||
t.Error("Lookup(no prefix) succeeded")
|
||||
}
|
||||
// Duplicate name is rejected.
|
||||
if _, err := store.Create("dash"); err == nil {
|
||||
t.Error("Create duplicate name succeeded; want error")
|
||||
}
|
||||
// Revocation is immediate.
|
||||
if err := store.Delete("dash"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := store.Lookup(raw); ok {
|
||||
t.Error("Lookup after Delete succeeded")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"nadir/internal/module"
|
||||
"nadir/internal/rbac"
|
||||
)
|
||||
|
||||
// File is the on-disk YAML shape. Module keys and permission values may be
|
||||
// "*" (must be quoted in YAML to avoid alias syntax).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// server:
|
||||
// secure_tls: false
|
||||
// roles:
|
||||
// admin:
|
||||
// "*": ["*"]
|
||||
// auditor:
|
||||
// "*": [read]
|
||||
// assignments:
|
||||
// urania: [admin]
|
||||
//
|
||||
// LogFiles is the allowlist of log files readable per unit via the file log
|
||||
// source. The path query parameter is matched against this set, so an admin
|
||||
// (not the caller) decides which files are exposable. Example:
|
||||
//
|
||||
// log_files:
|
||||
// nginx.service:
|
||||
// - /var/log/nginx/access.log
|
||||
// - /var/log/nginx/error.log
|
||||
type File struct {
|
||||
Server Server `yaml:"server"`
|
||||
Roles map[string]map[string][]string `yaml:"roles"`
|
||||
Assignments map[string][]string `yaml:"assignments"`
|
||||
LogFiles map[string][]string `yaml:"log_files"`
|
||||
}
|
||||
|
||||
// Server holds process-level settings.
|
||||
type Server struct {
|
||||
// SecureTLS controls the Secure attribute on the session cookie. A pointer
|
||||
// so an omitted key means "unset" (defaults to true / production-safe)
|
||||
// rather than the zero value false. Set false for local HTTP development.
|
||||
SecureTLS *bool `yaml:"secure_tls"`
|
||||
Hostname string `yaml:"hostname"`
|
||||
Port string `yaml:"port"`
|
||||
// TLS is provided one of three ways, in priority order:
|
||||
// 1. TrustProxy true - a reverse proxy (e.g. https://example.com) does
|
||||
// TLS; nadir serves plaintext HTTP and trusts X-Forwarded-For. nadir
|
||||
// must then be reachable ONLY by the proxy (bind localhost).
|
||||
// 2. TLSCert+TLSKey - nadir terminates TLS with this PEM pair.
|
||||
// 3. neither - nadir self-signs in memory (dev only).
|
||||
// Keep secure_tls true in modes 1 and 2 so the session cookie stays Secure.
|
||||
TrustProxy bool `yaml:"trust_proxy"`
|
||||
TLSCert string `yaml:"tls_cert"`
|
||||
TLSKey string `yaml:"tls_key"`
|
||||
}
|
||||
|
||||
// SecureCookie reports whether the session cookie should carry the Secure
|
||||
// attribute, defaulting to true when server.secure_tls is omitted.
|
||||
func (f *File) SecureCookie() bool {
|
||||
if f.Server.SecureTLS == nil {
|
||||
return true
|
||||
}
|
||||
return *f.Server.SecureTLS
|
||||
}
|
||||
|
||||
// DefaultPath returns the default configuration path at
|
||||
// ~/.config/nadir/config.yaml (honoring XDG_CONFIG_HOME via os.UserConfigDir).
|
||||
func DefaultPath() (string, error) {
|
||||
dir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("detect config dir: %w", err)
|
||||
}
|
||||
return filepath.Join(dir, "nadir", "config.yaml"), nil
|
||||
}
|
||||
|
||||
// ExpandPath expands leading ~ to the current user's home directory.
|
||||
func ExpandPath(path string) (string, error) {
|
||||
if strings.HasPrefix(path, "~/") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("detect home dir: %w", err)
|
||||
}
|
||||
return filepath.Join(home, path[2:]), nil
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// Load reads and parses the YAML file at path.
|
||||
func Load(path string) (*File, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config %s: %w", path, err)
|
||||
}
|
||||
var f File
|
||||
if err := yaml.Unmarshal(data, &f); err != nil {
|
||||
return nil, fmt.Errorf("parse config %s: %w", path, err)
|
||||
}
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
// Apply validates the config against the loaded modules and installs roles
|
||||
// + assignments into the RBAC store. Any reference to an unknown module or
|
||||
// unknown permission causes a startup error.
|
||||
func Apply(f *File, roles *rbac.RBAC, mods []module.Module) error {
|
||||
// Build a lookup: module ID -> the permissions it exposes.
|
||||
knownPerms := map[string]map[rbac.Permission]bool{}
|
||||
for _, m := range mods {
|
||||
set := map[rbac.Permission]bool{}
|
||||
for _, p := range m.Permissions() {
|
||||
set[p] = true
|
||||
}
|
||||
knownPerms[m.ID()] = set
|
||||
}
|
||||
|
||||
// 1. Define roles, validating every named module + permission.
|
||||
for roleName, grants := range f.Roles {
|
||||
converted := map[string][]rbac.Permission{}
|
||||
for modKey, perms := range grants {
|
||||
if modKey != rbac.Wildcard {
|
||||
if _, ok := knownPerms[modKey]; !ok {
|
||||
return fmt.Errorf("role %q references unknown module %q", roleName, modKey)
|
||||
}
|
||||
}
|
||||
permList := make([]rbac.Permission, 0, len(perms))
|
||||
for _, raw := range perms {
|
||||
p := rbac.Permission(raw)
|
||||
if p == rbac.All {
|
||||
permList = append(permList, p)
|
||||
continue
|
||||
}
|
||||
if !permissionExists(p, modKey, knownPerms) {
|
||||
return fmt.Errorf("role %q grants %q on module %q, but that permission is not exported by any matching module", roleName, raw, modKey)
|
||||
}
|
||||
permList = append(permList, p)
|
||||
}
|
||||
converted[modKey] = permList
|
||||
}
|
||||
roles.DefineRole(rbac.Role{Name: roleName, ModuleGrants: converted})
|
||||
}
|
||||
|
||||
// 2. Apply assignments, validating role names.
|
||||
for user, assigned := range f.Assignments {
|
||||
for _, roleName := range assigned {
|
||||
if !roles.RoleExists(roleName) {
|
||||
return fmt.Errorf("user %q assigned to unknown role %q", user, roleName)
|
||||
}
|
||||
roles.AssignRole(user, roleName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// permissionExists returns true if perm is exported by the named module,
|
||||
// or by any module when modKey is the wildcard.
|
||||
func permissionExists(perm rbac.Permission, modKey string, known map[string]map[rbac.Permission]bool) bool {
|
||||
if modKey == rbac.Wildcard {
|
||||
for _, perms := range known {
|
||||
if perms[perm] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
return known[modKey][perm]
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/module"
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// fakeModule implements module.Module with a fixed permission set, so config
|
||||
// tests don't depend on the concrete modules' exec behavior.
|
||||
type fakeModule struct {
|
||||
id string
|
||||
perms []rbac.Permission
|
||||
}
|
||||
|
||||
func (f fakeModule) ID() string { return f.id }
|
||||
func (f fakeModule) Name() string { return f.id }
|
||||
func (f fakeModule) Permissions() []rbac.Permission { return f.perms }
|
||||
func (f fakeModule) Register(huma.API) {}
|
||||
|
||||
func mods() []module.Module {
|
||||
return []module.Module{
|
||||
fakeModule{id: "system", perms: []rbac.Permission{rbac.Read, rbac.Write, rbac.Root}},
|
||||
fakeModule{id: "services", perms: []rbac.Permission{rbac.Read, rbac.Write}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureCookieDefaultsTrue(t *testing.T) {
|
||||
if !(&File{}).SecureCookie() {
|
||||
t.Error("omitted secure_tls should default to true")
|
||||
}
|
||||
no := false
|
||||
if (&File{Server: Server{SecureTLS: &no}}).SecureCookie() {
|
||||
t.Error("secure_tls: false should disable the Secure flag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "config.yaml")
|
||||
os.WriteFile(path, []byte(`
|
||||
server:
|
||||
secure_tls: false
|
||||
roles:
|
||||
admin:
|
||||
"*": ["*"]
|
||||
assignments:
|
||||
urania: [admin]
|
||||
`), 0600)
|
||||
|
||||
f, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if f.SecureCookie() {
|
||||
t.Error("secure_tls: false not parsed")
|
||||
}
|
||||
if len(f.Roles["admin"]) == 0 || len(f.Assignments["urania"]) == 0 {
|
||||
t.Error("roles/assignments not parsed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMissingFile(t *testing.T) {
|
||||
if _, err := Load(filepath.Join(t.TempDir(), "nope.yaml")); err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyValid(t *testing.T) {
|
||||
f := &File{
|
||||
Roles: map[string]map[string][]string{
|
||||
"admin": {"*": {"*"}},
|
||||
"auditor": {"*": {"read"}},
|
||||
"sysop": {"system": {"read", "root"}},
|
||||
},
|
||||
Assignments: map[string][]string{"urania": {"admin"}, "bob": {"sysop"}},
|
||||
}
|
||||
roles := rbac.New()
|
||||
if err := Apply(f, roles, mods()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !roles.Can("urania", "services", rbac.Write) {
|
||||
t.Error("admin wildcard not applied")
|
||||
}
|
||||
if !roles.Can("bob", "system", rbac.Root) || roles.Can("bob", "system", rbac.Write) {
|
||||
t.Error("sysop grants not applied correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
f *File
|
||||
}{
|
||||
{"unknown module", &File{Roles: map[string]map[string][]string{
|
||||
"r": {"firewall": {"read"}}}}},
|
||||
{"unknown perm on known module", &File{Roles: map[string]map[string][]string{
|
||||
"r": {"system": {"banana"}}}}},
|
||||
{"wildcard module with perm no module exports", &File{Roles: map[string]map[string][]string{
|
||||
"r": {"*": {"banana"}}}}},
|
||||
{"assignment to unknown role", &File{
|
||||
Roles: map[string]map[string][]string{"r": {"system": {"read"}}},
|
||||
Assignments: map[string][]string{"u": {"ghost"}}}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := Apply(tt.f, rbac.New(), mods()); err == nil {
|
||||
t.Errorf("expected error for %s", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyWildcardPermAlwaysOK(t *testing.T) {
|
||||
f := &File{Roles: map[string]map[string][]string{"r": {"system": {"*"}}}}
|
||||
if err := Apply(f, rbac.New(), mods()); err != nil {
|
||||
t.Fatalf("wildcard permission should validate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPathAndExpandPath(t *testing.T) {
|
||||
defaultPath, err := DefaultPath()
|
||||
if err != nil {
|
||||
t.Fatalf("DefaultPath failed: %v", err)
|
||||
}
|
||||
expectedSuffix := filepath.Join("nadir", "config.yaml")
|
||||
if !filepath.IsAbs(defaultPath) || !strings.HasSuffix(defaultPath, expectedSuffix) {
|
||||
t.Errorf("expected default path to end with %q and be absolute, got %q", expectedSuffix, defaultPath)
|
||||
}
|
||||
|
||||
expanded, err := ExpandPath("~/foo/config.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("ExpandPath failed: %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(expanded) || !strings.HasSuffix(expanded, filepath.Join("foo", "config.yaml")) {
|
||||
t.Errorf("ExpandPath did not resolve ~/ correctly, got %q", expanded)
|
||||
}
|
||||
|
||||
plain := "/etc/nadir/config.yaml"
|
||||
expandedPlain, err := ExpandPath(plain)
|
||||
if err != nil {
|
||||
t.Fatalf("ExpandPath failed for plain path: %v", err)
|
||||
}
|
||||
if expandedPlain != plain {
|
||||
t.Errorf("expected no-op for %q, got %q", plain, expandedPlain)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"nadir"
|
||||
"nadir/internal/auth"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
type HealthOutput struct {
|
||||
Body struct {
|
||||
Status string `json:"status" example:"ok" doc:"Overall health"`
|
||||
Database string `json:"database" example:"ok" doc:"Embedded SQLite session store state"`
|
||||
Version string `json:"version" example:"1.0.0" doc:"Application version"`
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHealth adds a public liveness/readiness probe. It is intentionally
|
||||
// unauthenticated (no permission metadata) so load balancers and orchestrators
|
||||
// can reach it. Returns 503 when the SQLite session store is unreachable, so
|
||||
// probes can key off the status code without parsing the body.
|
||||
func RegisterHealth(api huma.API, sessions *auth.SessionStore) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "health",
|
||||
Method: "GET",
|
||||
Path: "/api/health",
|
||||
Summary: "Health check",
|
||||
Description: "Public liveness/readiness probe. Reports whether the embedded " +
|
||||
"SQLite session store is reachable. Returns 503 when it is not.",
|
||||
Tags: []string{"Meta"},
|
||||
Errors: []int{503},
|
||||
}, func(ctx context.Context, _ *struct{}) (*HealthOutput, error) {
|
||||
if err := sessions.Ping(); err != nil {
|
||||
return nil, huma.Error503ServiceUnavailable("session database unreachable", err)
|
||||
}
|
||||
out := &HealthOutput{}
|
||||
out.Body.Status = "ok"
|
||||
out.Body.Database = "ok"
|
||||
out.Body.Version = nadir.Version
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"slices"
|
||||
|
||||
"nadir/internal/module"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// ModuleInfo describes a registered module: its stable ID, display name, and
|
||||
// the permissions it exposes. The frontend uses this to drive navigation and
|
||||
// render the role/permission matrix.
|
||||
type ModuleInfo struct {
|
||||
ID string `json:"id" example:"system" doc:"Stable module identifier"`
|
||||
Name string `json:"name" example:"System" doc:"Human-readable module name"`
|
||||
Permissions []string `json:"permissions" doc:"Permissions this module exposes (never includes the \"*\" wildcard)"`
|
||||
}
|
||||
|
||||
type ModulesOutput struct {
|
||||
Body struct {
|
||||
Modules []ModuleInfo `json:"modules" doc:"Registered modules, sorted by ID"`
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds the read-only module-discovery endpoint. It is intentionally
|
||||
// public: it exposes only the API's static shape (module IDs and permission
|
||||
// vocabulary), the same information already served by /openapi.json. The module
|
||||
// list is fixed at startup, so the response is computed once here.
|
||||
func Register(api huma.API, mods []module.Module) {
|
||||
infos := make([]ModuleInfo, 0, len(mods))
|
||||
for _, m := range mods {
|
||||
perms := m.Permissions()
|
||||
ps := make([]string, len(perms))
|
||||
for i, p := range perms {
|
||||
ps[i] = string(p)
|
||||
}
|
||||
infos = append(infos, ModuleInfo{ID: m.ID(), Name: m.Name(), Permissions: ps})
|
||||
}
|
||||
slices.SortFunc(infos, func(a, b ModuleInfo) int { return cmp.Compare(a.ID, b.ID) })
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "list-modules",
|
||||
Method: "GET",
|
||||
Path: "/api/_modules",
|
||||
Summary: "List registered modules",
|
||||
Description: "Returns every registered module with its ID, display name, " +
|
||||
"and exported permissions. Public (same static shape as /openapi.json); " +
|
||||
"used by the frontend for navigation and the role/permission matrix.",
|
||||
Tags: []string{"Meta"},
|
||||
}, func(ctx context.Context, _ *struct{}) (*ModulesOutput, error) {
|
||||
out := &ModulesOutput{}
|
||||
out.Body.Modules = infos
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package meta
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"nadir/internal/auth"
|
||||
"nadir/internal/module"
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// WhoamiInput carries the session cookie. The endpoint is not behind the RBAC
|
||||
// middleware (it requires no specific permission), so it validates the session
|
||||
// itself.
|
||||
type WhoamiInput struct {
|
||||
SessionID string `cookie:"nadir_session_id"`
|
||||
}
|
||||
|
||||
// WhoamiBody reports who the caller is and, per module, which permissions they
|
||||
// actually hold. Combined with /api/_modules (the full module/permission grid),
|
||||
// this gives the frontend everything it needs to render the permission matrix.
|
||||
type WhoamiBody struct {
|
||||
Username string `json:"username" example:"urania" doc:"Authenticated username"`
|
||||
Permissions map[string][]string `json:"permissions" doc:"Module ID -> permissions the caller holds. Modules where they hold none are omitted."`
|
||||
}
|
||||
|
||||
type WhoamiOutput struct{ Body WhoamiBody }
|
||||
|
||||
// RegisterWhoami adds the current-user endpoint. It resolves the caller's
|
||||
// concrete grants by asking the RBAC store about each module's permissions,
|
||||
// so "*" wildcards in roles are expanded for free.
|
||||
func RegisterWhoami(api huma.API, sessions *auth.SessionStore, roles *rbac.RBAC, mods []module.Module) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "whoami",
|
||||
Method: "GET",
|
||||
Path: "/api/whoami",
|
||||
Summary: "Get the current user and their permissions",
|
||||
Description: "Returns the authenticated username and, per module, the " +
|
||||
"permissions the caller holds (wildcards resolved). Pair with " +
|
||||
"/api/_modules to render the full permission matrix.",
|
||||
Tags: []string{"Meta"},
|
||||
Errors: []int{401},
|
||||
}, func(ctx context.Context, in *WhoamiInput) (*WhoamiOutput, error) {
|
||||
sess, ok := sessions.GetByToken(in.SessionID)
|
||||
if !ok {
|
||||
return nil, huma.Error401Unauthorized("unauthorized")
|
||||
}
|
||||
|
||||
held := make(map[string][]string)
|
||||
for _, m := range mods {
|
||||
var perms []string
|
||||
for _, p := range m.Permissions() {
|
||||
if roles.Can(sess.Username, m.ID(), p) {
|
||||
perms = append(perms, string(p))
|
||||
}
|
||||
}
|
||||
if len(perms) > 0 {
|
||||
held[m.ID()] = perms
|
||||
}
|
||||
}
|
||||
|
||||
out := &WhoamiOutput{}
|
||||
out.Body = WhoamiBody{Username: sess.Username, Permissions: held}
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package module
|
||||
|
||||
import (
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
type Module interface {
|
||||
ID() string
|
||||
Name() string
|
||||
Permissions() []rbac.Permission // permissions this module exposes (no "*")
|
||||
Register(api huma.API)
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"nadir/internal/auditlog"
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const ModuleID = "audit"
|
||||
|
||||
type Module struct {
|
||||
store *auditlog.Store
|
||||
}
|
||||
|
||||
func New(store *auditlog.Store) *Module { return &Module{store: store} }
|
||||
|
||||
func (m *Module) ID() string { return ModuleID }
|
||||
func (m *Module) Name() string { return "Audit" }
|
||||
|
||||
// Permissions: read to view the audit trail. There is no write - entries are
|
||||
// produced by the middleware, never by an API call.
|
||||
func (m *Module) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Read}
|
||||
}
|
||||
|
||||
// Types are named AuditList* (not ListInput/ListOutput) because Huma derives
|
||||
// OpenAPI schema names from the Go type name alone, not package-qualified, so a
|
||||
// bare "ListOutput" here would collide with the packages module's.
|
||||
type AuditListInput struct {
|
||||
Limit int `query:"limit" default:"200" minimum:"1" maximum:"10000" doc:"Max entries to return, newest first"`
|
||||
}
|
||||
|
||||
type AuditListOutput struct {
|
||||
Body struct {
|
||||
Entries []auditlog.Entry `json:"entries" doc:"Recorded actions, newest first"`
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Module) Register(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "audit-list",
|
||||
Method: "GET",
|
||||
Path: "/api/audit",
|
||||
Summary: "List recorded actions",
|
||||
Description: "Returns the audit trail of privileged write operations " +
|
||||
"(who, what, when, result), newest first.",
|
||||
Tags: []string{"Audit"},
|
||||
Metadata: map[string]any{"module": ModuleID, "permission": "read"},
|
||||
Errors: []int{401, 403, 500},
|
||||
}, func(ctx context.Context, in *AuditListInput) (*AuditListOutput, error) {
|
||||
entries, err := m.store.List(in.Limit)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read audit log failed", err)
|
||||
}
|
||||
out := &AuditListOutput{}
|
||||
out.Body.Entries = entries
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,225 @@
|
||||
package groups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const tagGroups = "Groups"
|
||||
|
||||
var groupPath = "/etc/group"
|
||||
|
||||
// systemGIDMax mirrors the user convention: regular groups start at 1000.
|
||||
const systemGIDMax = 1000
|
||||
|
||||
var (
|
||||
readErrors = []int{401, 403, 500}
|
||||
writeErrors = []int{400, 401, 403, 404, 409, 500}
|
||||
)
|
||||
|
||||
// groupNameRe matches valid group names (same rule as usernames). Starting with
|
||||
// a letter/underscore also rejects leading-dash flag injection.
|
||||
var groupNameRe = regexp.MustCompile(`^[a-z_][a-z0-9_-]{0,31}\$?$`)
|
||||
|
||||
// Group mirrors one /etc/group entry.
|
||||
type Group struct {
|
||||
Name string `json:"name" example:"wheel" doc:"Group name"`
|
||||
GID int `json:"gid" example:"10" doc:"Group ID"`
|
||||
Members []string `json:"members" doc:"Supplementary members (primary-group members are not listed here)"`
|
||||
System bool `json:"system" doc:"True for system groups (gid < 1000)"`
|
||||
}
|
||||
|
||||
type ListGroupsOutput struct {
|
||||
Body struct {
|
||||
Groups []Group `json:"groups" doc:"All groups from /etc/group"`
|
||||
}
|
||||
}
|
||||
|
||||
type GetGroupOutput struct{ Body Group }
|
||||
|
||||
type GroupPath struct {
|
||||
Group string `path:"group" example:"wheel" doc:"Group name"`
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
Body struct {
|
||||
Name string `json:"name" example:"developers" doc:"Group name"`
|
||||
GID *int `json:"gid,omitempty" example:"1500" doc:"Explicit GID (omit to auto-assign)"`
|
||||
System bool `json:"system,omitempty" doc:"Create a system group (groupadd --system)"`
|
||||
}
|
||||
}
|
||||
|
||||
func registerGroups(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "groups-list",
|
||||
Method: "GET",
|
||||
Path: "/api/groups",
|
||||
Summary: "List groups",
|
||||
Description: "Returns every group in /etc/group, including system groups " +
|
||||
"(flagged via `system`).",
|
||||
Tags: []string{tagGroups},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListGroupsOutput, error) {
|
||||
list, err := listGroups()
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+groupPath+" failed", err)
|
||||
}
|
||||
out := &ListGroupsOutput{}
|
||||
out.Body.Groups = list
|
||||
return out, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "groups-get",
|
||||
Method: "GET",
|
||||
Path: "/api/groups/{group}",
|
||||
Summary: "Get a single group",
|
||||
Description: "Returns one group by name. 404 if it does not exist.",
|
||||
Tags: []string{tagGroups},
|
||||
Metadata: op("read"),
|
||||
Errors: []int{400, 401, 403, 404, 500},
|
||||
}, func(ctx context.Context, in *GroupPath) (*GetGroupOutput, error) {
|
||||
if err := validateGroupName(in.Group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
g, ok, err := lookupGroup(in.Group)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+groupPath+" failed", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, huma.Error404NotFound("group not found: " + in.Group)
|
||||
}
|
||||
return &GetGroupOutput{Body: g}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "groups-create",
|
||||
Method: "POST",
|
||||
Path: "/api/groups",
|
||||
Summary: "Create a group",
|
||||
Description: "Creates a group via groupadd. 409 if the group already exists.",
|
||||
Tags: []string{tagGroups},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *CreateGroupInput) (*GetGroupOutput, error) {
|
||||
if err := validateGroupName(in.Body.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok, err := lookupGroup(in.Body.Name); err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+groupPath+" failed", err)
|
||||
} else if ok {
|
||||
return nil, huma.Error409Conflict("group already exists: " + in.Body.Name)
|
||||
}
|
||||
|
||||
args := []string{}
|
||||
if in.Body.System {
|
||||
args = append(args, "--system")
|
||||
}
|
||||
if in.Body.GID != nil {
|
||||
args = append(args, "-g", strconv.Itoa(*in.Body.GID))
|
||||
}
|
||||
args = append(args, "--", in.Body.Name)
|
||||
if _, err := oscmd.Run("groupadd", args...); err != nil {
|
||||
return nil, huma.Error500InternalServerError("groupadd failed", err)
|
||||
}
|
||||
|
||||
g, ok, err := lookupGroup(in.Body.Name)
|
||||
if err != nil || !ok {
|
||||
return nil, huma.Error500InternalServerError("group created but could not be read back", err)
|
||||
}
|
||||
return &GetGroupOutput{Body: g}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "groups-delete",
|
||||
Method: "DELETE",
|
||||
Path: "/api/groups/{group}",
|
||||
Summary: "Delete a group",
|
||||
Description: "Removes a group via groupdel. Returns 409 if it is the primary " +
|
||||
"group of an existing user (groupdel refuses), 404 if it does not exist.",
|
||||
Tags: []string{tagGroups},
|
||||
// Deleting a group is irreversible - gated behind root, not write.
|
||||
Metadata: op("root"),
|
||||
Errors: []int{400, 401, 403, 404, 409, 500},
|
||||
}, func(ctx context.Context, in *GroupPath) (*oscmd.StatusOutput, error) {
|
||||
if err := validateGroupName(in.Group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok, err := lookupGroup(in.Group); err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+groupPath+" failed", err)
|
||||
} else if !ok {
|
||||
return nil, huma.Error404NotFound("group not found: " + in.Group)
|
||||
}
|
||||
if _, err := oscmd.Run("groupdel", "--", in.Group); err != nil {
|
||||
// groupdel refuses to remove a user's primary group; make that actionable.
|
||||
return nil, huma.Error409Conflict("groupdel failed (is it a user's primary group?)", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
}
|
||||
|
||||
func validateGroupName(name string) error {
|
||||
if !groupNameRe.MatchString(name) {
|
||||
return huma.Error400BadRequest("invalid group name: " + name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listGroups() ([]Group, error) {
|
||||
data, err := os.ReadFile(groupPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseGroup(data), nil
|
||||
}
|
||||
|
||||
func lookupGroup(name string) (Group, bool, error) {
|
||||
list, err := listGroups()
|
||||
if err != nil {
|
||||
return Group{}, false, err
|
||||
}
|
||||
for _, g := range list {
|
||||
if g.Name == name {
|
||||
return g, true, nil
|
||||
}
|
||||
}
|
||||
return Group{}, false, nil
|
||||
}
|
||||
|
||||
// parseGroup parses /etc/group content. Blank, commented, or malformed lines
|
||||
// (fewer than 4 fields, non-numeric gid) are skipped.
|
||||
func parseGroup(data []byte) []Group {
|
||||
var groups []Group
|
||||
for line := range strings.SplitSeq(string(data), "\n") {
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
f := strings.Split(line, ":")
|
||||
if len(f) < 4 {
|
||||
continue
|
||||
}
|
||||
gid, err := strconv.Atoi(f[2])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var members []string
|
||||
if f[3] != "" {
|
||||
members = strings.Split(f[3], ",")
|
||||
}
|
||||
groups = append(groups, Group{
|
||||
Name: f[0],
|
||||
GID: gid,
|
||||
Members: members,
|
||||
System: gid < systemGIDMax,
|
||||
})
|
||||
}
|
||||
return groups
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package groups
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestGroupsHandlers(t *testing.T) {
|
||||
tempGroup := filepath.Join(t.TempDir(), "group")
|
||||
initialContent := "root:x:0:\nwheel:x:10:alice,bob\n"
|
||||
if err := os.WriteFile(tempGroup, []byte(initialContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
oldGroup := groupPath
|
||||
groupPath = tempGroup
|
||||
defer func() { groupPath = oldGroup }()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
registerGroups(api)
|
||||
|
||||
// 1. Test GET /api/groups
|
||||
resp := api.Get("/api/groups")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list groups: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var listRes ListGroupsOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(listRes.Body.Groups) != 2 {
|
||||
t.Errorf("got %d groups, want 2", len(listRes.Body.Groups))
|
||||
}
|
||||
|
||||
// 2. Test GET /api/groups/{group}
|
||||
resp = api.Get("/api/groups/wheel")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get group: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var getRes GetGroupOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &getRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if getRes.Body.Name != "wheel" || getRes.Body.GID != 10 {
|
||||
t.Errorf("get group: got %+v", getRes.Body)
|
||||
}
|
||||
|
||||
resp = api.Get("/api/groups/nonexistent")
|
||||
if resp.Code != http.StatusNotFound {
|
||||
t.Errorf("get non-existent group: got %d, want %d", resp.Code, http.StatusNotFound)
|
||||
}
|
||||
|
||||
// 3. Test POST /api/groups
|
||||
oscmd.SetMock("groupadd", func(args []string) oscmd.MockCommand {
|
||||
wantArgs := []string{"-g", "1500", "--", "dev"}
|
||||
if !reflect.DeepEqual(args, wantArgs) {
|
||||
t.Errorf("groupadd args: got %v, want %v", args, wantArgs)
|
||||
}
|
||||
devContent := initialContent + "dev:x:1500:\n"
|
||||
os.WriteFile(tempGroup, []byte(devContent), 0644)
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
})
|
||||
defer oscmd.ClearMocks()
|
||||
|
||||
gidVal := 1500
|
||||
resp = api.Post("/api/groups", struct {
|
||||
Name string `json:"name"`
|
||||
GID *int `json:"gid"`
|
||||
}{
|
||||
Name: "dev",
|
||||
GID: &gidVal,
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("create group: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 4. Test DELETE /api/groups/{group}
|
||||
oscmd.SetMock("groupdel", func(args []string) oscmd.MockCommand {
|
||||
wantArgs := []string{"--", "dev"}
|
||||
if !reflect.DeepEqual(args, wantArgs) {
|
||||
t.Errorf("groupdel args: got %v, want %v", args, wantArgs)
|
||||
}
|
||||
os.WriteFile(tempGroup, []byte(initialContent), 0644)
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
})
|
||||
|
||||
resp = api.Delete("/api/groups/dev")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("delete group: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package groups
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseGroup(t *testing.T) {
|
||||
data := []byte(`root:x:0:
|
||||
# comment
|
||||
|
||||
wheel:x:10:alice,bob
|
||||
developers:x:1500:alice
|
||||
empty:x:1600:
|
||||
broken:x:notanumber:x
|
||||
short:x:5
|
||||
`)
|
||||
got := parseGroup(data)
|
||||
if len(got) != 4 {
|
||||
t.Fatalf("expected 4 valid groups, got %d: %+v", len(got), got)
|
||||
}
|
||||
|
||||
wheel := got[1]
|
||||
if wheel.Name != "wheel" || wheel.GID != 10 || !wheel.System ||
|
||||
!reflect.DeepEqual(wheel.Members, []string{"alice", "bob"}) {
|
||||
t.Errorf("wheel parsed wrong: %+v", wheel)
|
||||
}
|
||||
|
||||
dev := got[2]
|
||||
if dev.GID != 1500 || dev.System {
|
||||
t.Errorf("developers should be a non-system group: %+v", dev)
|
||||
}
|
||||
|
||||
empty := got[3]
|
||||
if len(empty.Members) != 0 {
|
||||
t.Errorf("empty group should have no members, got %v", empty.Members)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGroupName(t *testing.T) {
|
||||
for _, n := range []string{"wheel", "_svc", "dev-team", "g1"} {
|
||||
if err := validateGroupName(n); err != nil {
|
||||
t.Errorf("validateGroupName(%q) = %v, want nil", n, err)
|
||||
}
|
||||
}
|
||||
for _, n := range []string{"", "-x", "Wheel", "a,b", "foo;rm", "1grp"} {
|
||||
if err := validateGroupName(n); err == nil {
|
||||
t.Errorf("validateGroupName(%q) = nil, want error", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package groups
|
||||
|
||||
import (
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const ModuleID = "groups"
|
||||
|
||||
type Module struct{}
|
||||
|
||||
func New() *Module { return &Module{} }
|
||||
|
||||
func (m *Module) ID() string { return ModuleID }
|
||||
func (m *Module) Name() string { return "Groups" }
|
||||
|
||||
// Permissions: read to list/inspect groups; write to create; root to delete
|
||||
// (irreversible). Group membership lives in the users module.
|
||||
func (m *Module) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Read, rbac.Write, rbac.Root}
|
||||
}
|
||||
|
||||
func (m *Module) Register(api huma.API) {
|
||||
registerGroups(api)
|
||||
}
|
||||
|
||||
func op(permission string) map[string]any {
|
||||
return map[string]any{"module": ModuleID, "permission": permission}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
)
|
||||
|
||||
// backend is the write-side abstraction. Each host network manager (nmcli,
|
||||
// networkd, ifupdown) implements this. Reads go through `ip -j` and
|
||||
// /etc/resolv.conf regardless - they are backend-agnostic (see read.go).
|
||||
//
|
||||
// When detect() finds no backend, Module.be is nil and all write endpoints
|
||||
// return 501 Not Implemented. Reads still work.
|
||||
// Methods that shell out take a context so a request that is cancelled (client
|
||||
// disconnect, timeout) kills the slow command (e.g. `nmcli con up` waiting on
|
||||
// DHCP). The timer-driven auto-revert, which must finish even with no client,
|
||||
// passes context.Background().
|
||||
type backend interface {
|
||||
// Name returns the backend identifier ("nmcli", "networkd", "ifupdown").
|
||||
Name() string
|
||||
|
||||
// Snapshot captures the current IPv4 configuration of iface so it can be
|
||||
// restored on rollback. The returned IfaceConfig is backend-specific:
|
||||
// nmcli reads from NM's connection, networkd/ifupdown read their config
|
||||
// files, falling back to live `ip` output when no managed file exists
|
||||
// (in which case Method is "dhcp" - safest revert assumption).
|
||||
Snapshot(ctx context.Context, iface string) (IfaceConfig, error)
|
||||
|
||||
// Apply replaces the interface's IPv4 configuration with cfg. It is the
|
||||
// caller's responsibility to have taken a Snapshot first (the rollback
|
||||
// mechanism does this). Apply must be idempotent: calling it with the
|
||||
// same cfg twice should leave the system in the same state.
|
||||
Apply(ctx context.Context, iface string, cfg IfaceConfig) error
|
||||
|
||||
// SetLinkUp brings the interface up.
|
||||
SetLinkUp(ctx context.Context, iface string) error
|
||||
|
||||
// SetLinkDown takes the interface down.
|
||||
SetLinkDown(ctx context.Context, iface string) error
|
||||
}
|
||||
|
||||
// detect probes the host for a supported network manager, in priority order:
|
||||
//
|
||||
// 1. nmcli (NetworkManager) - the majority of desktop and modern server installs
|
||||
// 2. networkctl (systemd-networkd) - common on minimal/container hosts
|
||||
// 3. ifup/ifdown (ifupdown) - classic Debian/Ubuntu servers
|
||||
//
|
||||
// Returns nil when none is found. The order matters: some distros ship both NM
|
||||
// and networkd; NM wins because it's the active manager in that case.
|
||||
func detect() backend {
|
||||
if _, err := exec.LookPath("nmcli"); err == nil {
|
||||
if _, err := oscmd.Run("nmcli", "general", "status"); err == nil {
|
||||
return &nmcliBackend{}
|
||||
}
|
||||
}
|
||||
if _, err := exec.LookPath("networkctl"); err == nil {
|
||||
if _, err := oscmd.Run("systemctl", "is-active", "--quiet", "systemd-networkd"); err == nil {
|
||||
return &networkdBackend{}
|
||||
}
|
||||
}
|
||||
if _, err := exec.LookPath("ifup"); err == nil {
|
||||
if _, err := exec.LookPath("ifdown"); err == nil {
|
||||
return &ifupdownBackend{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pendingChange tracks a single in-flight change that has been applied but not
|
||||
// yet confirmed. revert undoes it (re-apply the prior config, or bring a
|
||||
// downed link back up). If the timer fires before confirmation, revert runs -
|
||||
// protecting against lock-yourself-out mistakes.
|
||||
//
|
||||
// ponytail: one slot for the whole module, not per-interface. An admin makes one
|
||||
// change at a time; a concurrent change to another iface is rejected with a 409
|
||||
// that says so. Key it by iface (a map) if multi-interface concurrency is ever
|
||||
// needed.
|
||||
type pendingChange struct {
|
||||
Iface string // interface that was changed
|
||||
revert func() error // undoes the change, for rollback
|
||||
Timer *time.Timer // fires the auto-revert
|
||||
Deadline time.Time // when the timer will fire (for the status endpoint)
|
||||
}
|
||||
|
||||
// errNoBackend is the 501 returned when no write backend was detected.
|
||||
var errNoBackend = fmt.Errorf("no supported network backend detected (tried nmcli, networkctl, ifup/ifdown)")
|
||||
@@ -0,0 +1,133 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"regexp"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// tagNetworking groups every networking-module operation under one OpenAPI tag,
|
||||
// keeping tags 1:1 with modules.
|
||||
const tagNetworking = "Networking"
|
||||
|
||||
var (
|
||||
readErrors = []int{401, 403, 500}
|
||||
writeErrors = []int{400, 401, 403, 409, 500, 501}
|
||||
)
|
||||
|
||||
// ifaceNameRe matches a Linux interface name. Linux caps names at 15 bytes and
|
||||
// forbids '/' and whitespace; we additionally reject a leading dash so a name
|
||||
// can never be read as a flag by `ip`/`nmcli`/`ifup`. Every command line also
|
||||
// passes user-supplied names after a "--" separator.
|
||||
var ifaceNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._@-]{0,14}$`)
|
||||
|
||||
func validateIface(name string) error {
|
||||
if !ifaceNameRe.MatchString(name) {
|
||||
return huma.Error400BadRequest("invalid interface name: " + name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IfaceConfig is the declarative desired state for one interface. A PUT replaces
|
||||
// the interface's IPv4 configuration with this (and IPv6 too, when the IPv6 block
|
||||
// is included), across whichever backend is in use. Routes are part of the config
|
||||
// (declarative add/remove): the client sends the full set it wants.
|
||||
type IfaceConfig struct {
|
||||
Method string `json:"method" enum:"static,dhcp" example:"static" doc:"\"static\" for a fixed address, \"dhcp\" for automatic"`
|
||||
Address string `json:"address,omitempty" example:"192.168.1.10" doc:"IPv4 address (static only)"`
|
||||
Prefix int `json:"prefix,omitempty" minimum:"0" maximum:"32" example:"24" doc:"Network prefix length (static only)"`
|
||||
Gateway string `json:"gateway,omitempty" example:"192.168.1.1" doc:"Default gateway (static only, optional)"`
|
||||
IPv6 *IPv6Config `json:"ipv6,omitempty" doc:"Optional IPv6 settings. Omit to leave IPv6 untouched; include to manage it."`
|
||||
DNS []string `json:"dns,omitempty" example:"[\"1.1.1.1\",\"8.8.8.8\"]" doc:"DNS servers for this interface (IPv4 or IPv6)"`
|
||||
Routes []Route `json:"routes,omitempty" doc:"Static routes to install for this interface"`
|
||||
RollbackSeconds int `json:"rollback_seconds,omitempty" minimum:"0" maximum:"3600" example:"60" doc:"Auto-revert after this many seconds unless confirmed. 0 uses the default (60s)."`
|
||||
}
|
||||
|
||||
// IPv6Config is the optional IPv6 settings for an interface. Method "auto" uses
|
||||
// SLAAC/router advertisements (the usual default), "static" pins an address, and
|
||||
// "ignore" disables IPv6 on the interface. DHCPv6 is not modeled.
|
||||
type IPv6Config struct {
|
||||
Method string `json:"method" enum:"auto,static,ignore" example:"static" doc:"\"auto\" (SLAAC), \"static\", or \"ignore\" (disable IPv6)"`
|
||||
Address string `json:"address,omitempty" example:"2001:db8::10" doc:"IPv6 address (static only)"`
|
||||
Prefix int `json:"prefix,omitempty" minimum:"0" maximum:"128" example:"64" doc:"Prefix length (static only)"`
|
||||
Gateway string `json:"gateway,omitempty" example:"2001:db8::1" doc:"IPv6 default gateway (static only, optional)"`
|
||||
}
|
||||
|
||||
// Route is a single static route. Destination is a CIDR (or \"default\").
|
||||
type Route struct {
|
||||
Destination string `json:"destination" example:"10.0.0.0/24" doc:"Destination network in CIDR notation, or \"default\""`
|
||||
Gateway string `json:"gateway" example:"192.168.1.1" doc:"Next-hop gateway"`
|
||||
}
|
||||
|
||||
// validate checks the desired config independently of the backend, so a bad
|
||||
// request is a 400 before we touch the system. IP/CIDR parsing uses net/netip.
|
||||
func (c IfaceConfig) validate() error {
|
||||
switch c.Method {
|
||||
case "static":
|
||||
if c.Address == "" {
|
||||
return huma.Error400BadRequest("static method requires an address")
|
||||
}
|
||||
if _, err := netip.ParseAddr(c.Address); err != nil {
|
||||
return huma.Error400BadRequest("invalid address: " + c.Address)
|
||||
}
|
||||
if c.Prefix < 1 || c.Prefix > 32 {
|
||||
return huma.Error400BadRequest("prefix must be 1-32")
|
||||
}
|
||||
if c.Gateway != "" {
|
||||
if _, err := netip.ParseAddr(c.Gateway); err != nil {
|
||||
return huma.Error400BadRequest("invalid gateway: " + c.Gateway)
|
||||
}
|
||||
}
|
||||
case "dhcp":
|
||||
// address/gateway/prefix are ignored; nothing to validate.
|
||||
default:
|
||||
return huma.Error400BadRequest("method must be \"static\" or \"dhcp\"")
|
||||
}
|
||||
for _, s := range c.DNS {
|
||||
if _, err := netip.ParseAddr(s); err != nil {
|
||||
return huma.Error400BadRequest("invalid DNS server: " + s)
|
||||
}
|
||||
}
|
||||
for _, r := range c.Routes {
|
||||
if r.Destination != "default" {
|
||||
if _, err := netip.ParsePrefix(r.Destination); err != nil {
|
||||
return huma.Error400BadRequest("invalid route destination: " + r.Destination)
|
||||
}
|
||||
}
|
||||
if _, err := netip.ParseAddr(r.Gateway); err != nil {
|
||||
return huma.Error400BadRequest("invalid route gateway: " + r.Gateway)
|
||||
}
|
||||
}
|
||||
if c.IPv6 != nil {
|
||||
if err := c.IPv6.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate checks an IPv6 block. Static addresses/gateways must parse as IPv6
|
||||
// (an IPv4 literal here is a client mistake).
|
||||
func (c IPv6Config) validate() error {
|
||||
switch c.Method {
|
||||
case "auto", "ignore":
|
||||
// no address fields to validate
|
||||
case "static":
|
||||
addr, err := netip.ParseAddr(c.Address)
|
||||
if err != nil || addr.Is4() {
|
||||
return huma.Error400BadRequest("ipv6 static requires a valid IPv6 address, got: " + c.Address)
|
||||
}
|
||||
if c.Prefix < 1 || c.Prefix > 128 {
|
||||
return huma.Error400BadRequest("ipv6 prefix must be 1-128")
|
||||
}
|
||||
if c.Gateway != "" {
|
||||
if gw, err := netip.ParseAddr(c.Gateway); err != nil || gw.Is4() {
|
||||
return huma.Error400BadRequest("invalid ipv6 gateway: " + c.Gateway)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return huma.Error400BadRequest("ipv6 method must be \"auto\", \"static\", or \"ignore\"")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package networking
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateIfaceConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg IfaceConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid static", IfaceConfig{
|
||||
Method: "static", Address: "192.168.1.10", Prefix: 24, Gateway: "192.168.1.1",
|
||||
DNS: []string{"1.1.1.1", "8.8.8.8"},
|
||||
}, false},
|
||||
{"valid static no gateway", IfaceConfig{
|
||||
Method: "static", Address: "10.0.0.5", Prefix: 8,
|
||||
}, false},
|
||||
{"valid dhcp", IfaceConfig{Method: "dhcp"}, false},
|
||||
{"valid with routes", IfaceConfig{
|
||||
Method: "static", Address: "192.168.1.10", Prefix: 24,
|
||||
Routes: []Route{
|
||||
{Destination: "100.64.0.0/24", Gateway: "192.168.1.1"},
|
||||
{Destination: "default", Gateway: "192.168.1.1"},
|
||||
},
|
||||
}, false},
|
||||
{"static missing address", IfaceConfig{Method: "static", Prefix: 24}, true},
|
||||
{"static bad address", IfaceConfig{Method: "static", Address: "not-an-ip", Prefix: 24}, true},
|
||||
{"static bad gateway", IfaceConfig{Method: "static", Address: "10.0.0.1", Prefix: 24, Gateway: "bad"}, true},
|
||||
{"bad method", IfaceConfig{Method: "pppoe"}, true},
|
||||
{"bad dns", IfaceConfig{Method: "dhcp", DNS: []string{"not-an-ip"}}, true},
|
||||
{"bad route destination", IfaceConfig{
|
||||
Method: "static", Address: "10.0.0.1", Prefix: 24,
|
||||
Routes: []Route{{Destination: "nope", Gateway: "10.0.0.1"}},
|
||||
}, true},
|
||||
{"bad route gateway", IfaceConfig{
|
||||
Method: "static", Address: "10.0.0.1", Prefix: 24,
|
||||
Routes: []Route{{Destination: "10.0.0.0/24", Gateway: "nope"}},
|
||||
}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.cfg.validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIface(t *testing.T) {
|
||||
valid := []string{"eth0", "enp3s0", "wlan0", "br-lan", "veth1234567", "docker0"}
|
||||
for _, name := range valid {
|
||||
if err := validateIface(name); err != nil {
|
||||
t.Errorf("validateIface(%q) unexpected error: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
invalid := []string{"", "-eth0", "/dev/net", "a b", "name_that_is_way_too_long_for_linux"}
|
||||
for _, name := range invalid {
|
||||
if err := validateIface(name); err == nil {
|
||||
t.Errorf("validateIface(%q) expected error", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,208 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// /etc/hosts management ("Host Addresses"). Edits are surgical — upsert/delete a
|
||||
// single IP's line — so existing comments, ordering and the localhost entries
|
||||
// are preserved rather than rewritten away.
|
||||
|
||||
var hostsFile = "/etc/hosts"
|
||||
|
||||
// hostnameRe matches a single hostname/alias. It forbids whitespace and '#' (so
|
||||
// an entry can't inject extra fields or a comment) and a leading dash.
|
||||
var hostnameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`)
|
||||
|
||||
// HostEntry is one /etc/hosts mapping: an IP and the names that resolve to it.
|
||||
type HostEntry struct {
|
||||
IP string `json:"ip" example:"192.168.1.10" doc:"IPv4 or IPv6 address"`
|
||||
Hostnames []string `json:"hostnames" example:"[\"server\",\"server.local\"]" doc:"Names mapped to the address"`
|
||||
}
|
||||
|
||||
type ListHostsOutput struct {
|
||||
Body struct {
|
||||
Entries []HostEntry `json:"entries"`
|
||||
}
|
||||
}
|
||||
|
||||
type HostUpsertInput struct {
|
||||
IP string `path:"ip" example:"192.168.1.10" doc:"IP address to add or update"`
|
||||
Body struct {
|
||||
Hostnames []string `json:"hostnames" example:"[\"server\",\"server.local\"]" doc:"Names to map to the IP"`
|
||||
}
|
||||
}
|
||||
|
||||
type HostDeleteInput struct {
|
||||
IP string `path:"ip" example:"192.168.1.10" doc:"IP address whose entry to remove"`
|
||||
}
|
||||
|
||||
func registerHosts(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-list-hosts",
|
||||
Method: "GET",
|
||||
Path: "/api/networking/hosts",
|
||||
Summary: "List /etc/hosts entries",
|
||||
Description: "Returns the static host-to-address mappings from /etc/hosts.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListHostsOutput, error) {
|
||||
data, err := os.ReadFile(hostsFile)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read hosts failed", err)
|
||||
}
|
||||
res := &ListHostsOutput{}
|
||||
res.Body.Entries = parseHosts(string(data))
|
||||
return res, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-upsert-host",
|
||||
Method: "PUT",
|
||||
Path: "/api/networking/hosts/{ip}",
|
||||
Summary: "Add or update a /etc/hosts entry",
|
||||
Description: "Sets the hostnames for an IP. Replaces the existing line for that IP, " +
|
||||
"or appends a new one; all other lines (comments included) are left untouched.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *HostUpsertInput) (*oscmd.StatusOutput, error) {
|
||||
if err := validateHost(in.IP, in.Body.Hostnames); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := upsertHost(in.IP, in.Body.Hostnames); err != nil {
|
||||
return nil, huma.Error500InternalServerError("write hosts failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-delete-host",
|
||||
Method: "DELETE",
|
||||
Path: "/api/networking/hosts/{ip}",
|
||||
Summary: "Remove a /etc/hosts entry",
|
||||
Description: "Removes the line(s) mapping the given IP. 404 if no entry exists for it.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("write"),
|
||||
Errors: []int{400, 401, 403, 404, 500},
|
||||
}, func(ctx context.Context, in *HostDeleteInput) (*oscmd.StatusOutput, error) {
|
||||
if _, err := netip.ParseAddr(in.IP); err != nil {
|
||||
return nil, huma.Error400BadRequest("invalid IP: " + in.IP)
|
||||
}
|
||||
removed, err := deleteHost(in.IP)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("write hosts failed", err)
|
||||
}
|
||||
if !removed {
|
||||
return nil, huma.Error404NotFound("no hosts entry for " + in.IP)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
}
|
||||
|
||||
func validateHost(ip string, hostnames []string) error {
|
||||
if _, err := netip.ParseAddr(ip); err != nil {
|
||||
return huma.Error400BadRequest("invalid IP: " + ip)
|
||||
}
|
||||
if len(hostnames) == 0 {
|
||||
return huma.Error400BadRequest("at least one hostname is required")
|
||||
}
|
||||
for _, h := range hostnames {
|
||||
if !hostnameRe.MatchString(h) {
|
||||
return huma.Error400BadRequest("invalid hostname: " + h)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- parse / render (pure, tested) -------------------------------------------
|
||||
|
||||
// hostFields returns the IP and hostnames on a hosts line, or ok=false for a
|
||||
// blank or comment-only line. An inline "# comment" tail is stripped.
|
||||
func hostFields(line string) (ip string, names []string, ok bool) {
|
||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||
line = line[:i]
|
||||
}
|
||||
f := strings.Fields(line)
|
||||
if len(f) < 2 {
|
||||
return "", nil, false
|
||||
}
|
||||
return f[0], f[1:], true
|
||||
}
|
||||
|
||||
func parseHosts(text string) []HostEntry {
|
||||
entries := []HostEntry{}
|
||||
for line := range strings.SplitSeq(text, "\n") {
|
||||
if ip, names, ok := hostFields(line); ok {
|
||||
entries = append(entries, HostEntry{IP: ip, Hostnames: names})
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
// renderHostLine formats one entry as a hosts line.
|
||||
func renderHostLine(ip string, hostnames []string) string {
|
||||
return ip + "\t" + strings.Join(hostnames, " ")
|
||||
}
|
||||
|
||||
// --- writes ------------------------------------------------------------------
|
||||
|
||||
// upsertHost replaces the line for ip (matched on the address field) or appends
|
||||
// a new one, preserving every other line.
|
||||
func upsertHost(ip string, hostnames []string) error {
|
||||
data, err := os.ReadFile(hostsFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(data), "\n")
|
||||
newLine := renderHostLine(ip, hostnames)
|
||||
replaced := false
|
||||
for i, line := range lines {
|
||||
if got, _, ok := hostFields(line); ok && got == ip {
|
||||
lines[i] = newLine
|
||||
replaced = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !replaced {
|
||||
// Append, avoiding a blank line if the file already ended with one.
|
||||
if n := len(lines); n > 0 && strings.TrimSpace(lines[n-1]) == "" {
|
||||
lines[n-1] = newLine
|
||||
} else {
|
||||
lines = append(lines, newLine)
|
||||
}
|
||||
lines = append(lines, "")
|
||||
}
|
||||
return os.WriteFile(hostsFile, []byte(strings.Join(lines, "\n")), 0644)
|
||||
}
|
||||
|
||||
// deleteHost removes every line mapping ip and reports whether any were removed.
|
||||
func deleteHost(ip string) (bool, error) {
|
||||
data, err := os.ReadFile(hostsFile)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
lines := strings.Split(string(data), "\n")
|
||||
kept := make([]string, 0, len(lines))
|
||||
removed := false
|
||||
for _, line := range lines {
|
||||
if got, _, ok := hostFields(line); ok && got == ip {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, line)
|
||||
}
|
||||
if !removed {
|
||||
return false, nil
|
||||
}
|
||||
return true, os.WriteFile(hostsFile, []byte(strings.Join(kept, "\n")), 0644)
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const sampleHosts = `# static table
|
||||
127.0.0.1 localhost
|
||||
::1 localhost ip6-localhost
|
||||
192.168.1.10 server server.local # the box
|
||||
|
||||
`
|
||||
|
||||
func TestParseHosts(t *testing.T) {
|
||||
got := parseHosts(sampleHosts)
|
||||
want := []HostEntry{
|
||||
{IP: "127.0.0.1", Hostnames: []string{"localhost"}},
|
||||
{IP: "::1", Hostnames: []string{"localhost", "ip6-localhost"}},
|
||||
{IP: "192.168.1.10", Hostnames: []string{"server", "server.local"}},
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("parseHosts:\n got %+v\nwant %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertAndDeleteHost(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "hosts")
|
||||
if err := os.WriteFile(path, []byte(sampleHosts), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
old := hostsFile
|
||||
hostsFile = path
|
||||
defer func() { hostsFile = old }()
|
||||
|
||||
// Update an existing IP — comments and other lines must survive.
|
||||
if err := upsertHost("192.168.1.10", []string{"web", "web.local"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(path)
|
||||
if !strings.Contains(string(data), "# static table") || !strings.Contains(string(data), "ip6-localhost") {
|
||||
t.Errorf("upsert clobbered other lines:\n%s", data)
|
||||
}
|
||||
entries := parseHosts(string(data))
|
||||
if e := findHost(entries, "192.168.1.10"); e == nil || !reflect.DeepEqual(e.Hostnames, []string{"web", "web.local"}) {
|
||||
t.Errorf("upsert did not update entry: %+v", entries)
|
||||
}
|
||||
|
||||
// Add a new IP.
|
||||
if err := upsertHost("10.0.0.5", []string{"db"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
entries = parseHosts(mustRead(t, path))
|
||||
if findHost(entries, "10.0.0.5") == nil {
|
||||
t.Errorf("upsert did not append new entry: %+v", entries)
|
||||
}
|
||||
|
||||
// Delete it again.
|
||||
removed, err := deleteHost("10.0.0.5")
|
||||
if err != nil || !removed {
|
||||
t.Fatalf("delete failed: removed=%v err=%v", removed, err)
|
||||
}
|
||||
if findHost(parseHosts(mustRead(t, path)), "10.0.0.5") != nil {
|
||||
t.Error("entry still present after delete")
|
||||
}
|
||||
|
||||
// Deleting a missing IP reports not-removed.
|
||||
if removed, _ := deleteHost("8.8.8.8"); removed {
|
||||
t.Error("expected removed=false for missing IP")
|
||||
}
|
||||
}
|
||||
|
||||
func findHost(entries []HostEntry, ip string) *HostEntry {
|
||||
for i := range entries {
|
||||
if entries[i].IP == ip {
|
||||
return &entries[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustRead(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
@@ -0,0 +1,239 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
)
|
||||
|
||||
// ifupdownBackend implements backend via classic Debian ifupdown. Configuration
|
||||
// is done by writing stanza files under interfacesDDir. Each managed interface
|
||||
// gets "nadir-<iface>" so we never touch /etc/network/interfaces itself.
|
||||
//
|
||||
// Prerequisite: /etc/network/interfaces must contain
|
||||
//
|
||||
// source /etc/network/interfaces.d/*
|
||||
//
|
||||
// for these stanzas to take effect. If missing, Apply checks for this and adds
|
||||
// the source directive (same auto-provision pattern as the PAM service).
|
||||
type ifupdownBackend struct{}
|
||||
|
||||
var (
|
||||
interfacesFile = "/etc/network/interfaces"
|
||||
interfacesDDir = "/etc/network/interfaces.d"
|
||||
)
|
||||
|
||||
func (b *ifupdownBackend) Name() string { return "ifupdown" }
|
||||
|
||||
// ifupdownFile returns the path for the nadir-managed stanza file for iface.
|
||||
// iface is already validated by validateIface to prevent path traversal.
|
||||
func ifupdownFile(iface string) string {
|
||||
return filepath.Join(interfacesDDir, "nadir-"+iface)
|
||||
}
|
||||
|
||||
func (b *ifupdownBackend) Snapshot(ctx context.Context, iface string) (IfaceConfig, error) {
|
||||
path := ifupdownFile(iface)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
// No nadir-managed stanza → fall back to live ip output, assume DHCP.
|
||||
return snapshotFromIP(ctx, iface)
|
||||
}
|
||||
return parseIfupdownStanza(string(data)), nil
|
||||
}
|
||||
|
||||
// parseIfupdownStanza extracts IfaceConfig from an ifupdown stanza file. A file
|
||||
// may hold both an "inet" (IPv4) and an "inet6" (IPv6) stanza for the interface;
|
||||
// family tracks which one the indented keys below belong to.
|
||||
func parseIfupdownStanza(content string) IfaceConfig {
|
||||
cfg := IfaceConfig{Method: "dhcp"}
|
||||
family := "inet"
|
||||
|
||||
for line := range strings.SplitSeq(content, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch fields[0] {
|
||||
case "iface":
|
||||
// "iface eth0 inet static" / "iface eth0 inet6 auto"
|
||||
if len(fields) >= 4 {
|
||||
family = fields[2]
|
||||
switch family {
|
||||
case "inet":
|
||||
if fields[3] == "static" {
|
||||
cfg.Method = "static"
|
||||
}
|
||||
case "inet6":
|
||||
switch fields[3] {
|
||||
case "static":
|
||||
v6(&cfg).Method = "static"
|
||||
default: // auto / dhcp / manual → treat as autoconf
|
||||
v6(&cfg).Method = "auto"
|
||||
}
|
||||
}
|
||||
}
|
||||
case "address":
|
||||
addr, prefix := splitCIDR(fields[1])
|
||||
if family == "inet6" {
|
||||
g := v6(&cfg)
|
||||
g.Address = addr
|
||||
if prefix > 0 {
|
||||
g.Prefix = prefix
|
||||
}
|
||||
} else if addr != "" {
|
||||
cfg.Address = addr
|
||||
if prefix > 0 {
|
||||
cfg.Prefix = prefix
|
||||
}
|
||||
}
|
||||
case "gateway":
|
||||
if family == "inet6" {
|
||||
v6(&cfg).Gateway = fields[1]
|
||||
} else {
|
||||
cfg.Gateway = fields[1]
|
||||
}
|
||||
case "dns-nameservers":
|
||||
cfg.DNS = append(cfg.DNS, fields[1:]...)
|
||||
case "up":
|
||||
// "up ip route add 10.0.0.0/24 via 192.168.1.1"
|
||||
r := parseUpRoute(fields[1:])
|
||||
if r.Destination != "" {
|
||||
cfg.Routes = append(cfg.Routes, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// parseUpRoute extracts a Route from "ip route add <dst> via <gw>" post-up commands.
|
||||
func parseUpRoute(args []string) Route {
|
||||
// Expected: ["ip", "route", "add", "10.0.0.0/24", "via", "192.168.1.1"]
|
||||
if len(args) < 6 || args[0] != "ip" || args[1] != "route" || args[2] != "add" {
|
||||
return Route{}
|
||||
}
|
||||
r := Route{Destination: args[3]}
|
||||
for i, a := range args {
|
||||
if a == "via" && i+1 < len(args) {
|
||||
r.Gateway = args[i+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (b *ifupdownBackend) Apply(ctx context.Context, iface string, cfg IfaceConfig) error {
|
||||
// Ensure interfaces.d exists and is sourced.
|
||||
if err := ensureSourceDirective(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content := renderIfupdownStanza(iface, cfg)
|
||||
path := ifupdownFile(iface)
|
||||
|
||||
if err := os.MkdirAll(interfacesDDir, 0755); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", interfacesDDir, err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
return fmt.Errorf("write %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Bring the interface down and back up with the new config.
|
||||
// ifdown may fail if the interface wasn't previously managed — ignore that.
|
||||
_, _ = oscmd.RunContext(ctx, "ifdown", "--force", "--", iface)
|
||||
if _, err := oscmd.RunContext(ctx, "ifup", "--", iface); err != nil {
|
||||
return fmt.Errorf("ifup %s: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureSourceDirective checks that /etc/network/interfaces contains a source
|
||||
// line for interfaces.d. If not, appends one (same auto-provision pattern as
|
||||
// the PAM service file).
|
||||
func ensureSourceDirective() error {
|
||||
data, err := os.ReadFile(interfacesFile)
|
||||
if err != nil {
|
||||
// If the file doesn't exist at all, that's a broken system; don't create it.
|
||||
return fmt.Errorf("read %s: %w", interfacesFile, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
// Check for any source line pointing at interfaces.d.
|
||||
for line := range strings.SplitSeq(content, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "source") && strings.Contains(line, "interfaces.d") {
|
||||
return nil // already present
|
||||
}
|
||||
}
|
||||
|
||||
// Append the source directive.
|
||||
addition := "\n# Added by nadir to pick up per-interface stanzas.\nsource " + interfacesDDir + "/*\n"
|
||||
if err := os.WriteFile(interfacesFile, []byte(content+addition), 0644); err != nil {
|
||||
return fmt.Errorf("append source directive to %s: %w", interfacesFile, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// renderIfupdownStanza builds the content for an ifupdown stanza file.
|
||||
func renderIfupdownStanza(iface string, cfg IfaceConfig) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("# Managed by nadir — do not edit manually.\n")
|
||||
|
||||
switch cfg.Method {
|
||||
case "dhcp":
|
||||
fmt.Fprintf(&b, "auto %s\n", iface)
|
||||
fmt.Fprintf(&b, "iface %s inet dhcp\n", iface)
|
||||
case "static":
|
||||
fmt.Fprintf(&b, "auto %s\n", iface)
|
||||
fmt.Fprintf(&b, "iface %s inet static\n", iface)
|
||||
fmt.Fprintf(&b, " address %s/%d\n", cfg.Address, cfg.Prefix)
|
||||
if cfg.Gateway != "" {
|
||||
fmt.Fprintf(&b, " gateway %s\n", cfg.Gateway)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.DNS) > 0 {
|
||||
fmt.Fprintf(&b, " dns-nameservers %s\n", strings.Join(cfg.DNS, " "))
|
||||
}
|
||||
|
||||
for _, r := range cfg.Routes {
|
||||
fmt.Fprintf(&b, " up ip route add %s via %s\n", r.Destination, r.Gateway)
|
||||
fmt.Fprintf(&b, " down ip route del %s via %s\n", r.Destination, r.Gateway)
|
||||
}
|
||||
|
||||
// IPv6 goes in its own inet6 stanza (the "auto eth0" above covers both).
|
||||
if cfg.IPv6 != nil {
|
||||
switch cfg.IPv6.Method {
|
||||
case "auto":
|
||||
fmt.Fprintf(&b, "iface %s inet6 auto\n", iface)
|
||||
case "static":
|
||||
fmt.Fprintf(&b, "iface %s inet6 static\n", iface)
|
||||
fmt.Fprintf(&b, " address %s/%d\n", cfg.IPv6.Address, cfg.IPv6.Prefix)
|
||||
if cfg.IPv6.Gateway != "" {
|
||||
fmt.Fprintf(&b, " gateway %s\n", cfg.IPv6.Gateway)
|
||||
}
|
||||
case "ignore":
|
||||
// no inet6 stanza
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (b *ifupdownBackend) SetLinkUp(ctx context.Context, iface string) error {
|
||||
_, err := oscmd.RunContext(ctx, "ifup", "--", iface)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *ifupdownBackend) SetLinkDown(ctx context.Context, iface string) error {
|
||||
_, err := oscmd.RunContext(ctx, "ifdown", "--", iface)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const ModuleID = "networking"
|
||||
|
||||
type Module struct {
|
||||
// be is the detected network backend (nmcli / networkd / ifupdown). nil when
|
||||
// none was found: reads still work (they go through `ip`), writes return 501.
|
||||
be backend
|
||||
// pending holds the single in-flight change awaiting confirmation, for the
|
||||
// timed auto-rollback. See rollback.go.
|
||||
pending *pendingChange
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// New detects the host's network backend once at startup.
|
||||
func New() *Module { return &Module{be: detect()} }
|
||||
|
||||
func (m *Module) ID() string { return ModuleID }
|
||||
func (m *Module) Name() string { return "Networking" }
|
||||
|
||||
// Permissions: read to inspect interfaces/routes/DNS; write to reconfigure them
|
||||
// (apply config, bring links up/down, confirm a pending change).
|
||||
func (m *Module) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Read, rbac.Write}
|
||||
}
|
||||
|
||||
func (m *Module) Register(api huma.API) {
|
||||
registerReads(api)
|
||||
registerWrites(api, m)
|
||||
registerHosts(api)
|
||||
}
|
||||
|
||||
func op(permission string) map[string]any {
|
||||
return map[string]any{"module": ModuleID, "permission": permission}
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
)
|
||||
|
||||
// networkdBackend implements backend via systemd-networkd. Configuration is done
|
||||
// by writing .network files under networkdDir. Each managed interface gets its
|
||||
// own file named "90-nadir-<iface>.network" — the 90 prefix puts nadir's config
|
||||
// after most distro defaults, and the "nadir-" infix ensures we never clobber
|
||||
// distro-provided files.
|
||||
type networkdBackend struct{}
|
||||
|
||||
var networkdDir = "/etc/systemd/network"
|
||||
|
||||
func (b *networkdBackend) Name() string { return "networkd" }
|
||||
|
||||
// networkdFile returns the path for the nadir-managed .network file for iface.
|
||||
// iface is already validated by validateIface to prevent path traversal.
|
||||
func networkdFile(iface string) string {
|
||||
return filepath.Join(networkdDir, "90-nadir-"+iface+".network")
|
||||
}
|
||||
|
||||
func (b *networkdBackend) Snapshot(ctx context.Context, iface string) (IfaceConfig, error) {
|
||||
path := networkdFile(iface)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
// No nadir-managed file exists. Fall back to live ip output and assume
|
||||
// DHCP — the safest rollback assumption (reverts to "whatever the
|
||||
// system was doing before nadir touched it").
|
||||
return snapshotFromIP(ctx, iface)
|
||||
}
|
||||
return parseNetworkdFile(string(data)), nil
|
||||
}
|
||||
|
||||
// snapshotFromIP captures the current state from ip -j, assuming DHCP as the
|
||||
// method since there's no way to tell from live output.
|
||||
func snapshotFromIP(ctx context.Context, iface string) (IfaceConfig, error) {
|
||||
cfg := IfaceConfig{Method: "dhcp"}
|
||||
|
||||
// Grab addresses.
|
||||
out, err := oscmd.RunContext(ctx, "ip", "-j", "addr", "show", "--", iface)
|
||||
if err != nil {
|
||||
return cfg, nil // interface may not exist yet; DHCP fallback is fine
|
||||
}
|
||||
ifaces, err := parseInterfaces(out)
|
||||
if err != nil || len(ifaces) == 0 {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// If there are IPv4 addresses, capture the first one.
|
||||
if len(ifaces[0].IPv4) > 0 {
|
||||
addr, prefix := splitCIDR(ifaces[0].IPv4[0])
|
||||
cfg.Method = "static"
|
||||
cfg.Address = addr
|
||||
cfg.Prefix = prefix
|
||||
}
|
||||
|
||||
// Capture a global IPv6 address if present, skipping link-local (fe80::/10),
|
||||
// so a rollback restores the interface's real IPv6 state.
|
||||
cfg.IPv6 = &IPv6Config{Method: "auto"}
|
||||
for _, c := range ifaces[0].IPv6 {
|
||||
addr, prefix := splitCIDR(c)
|
||||
if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLinkLocalUnicast() {
|
||||
cfg.IPv6 = &IPv6Config{Method: "static", Address: addr, Prefix: prefix}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Grab the default gateway for this interface.
|
||||
routeOut, err := oscmd.RunContext(ctx, "ip", "-j", "route", "show", "dev", "--", iface)
|
||||
if err == nil {
|
||||
routes, _ := parseRoutes(routeOut)
|
||||
for _, r := range routes {
|
||||
if r.Destination == "default" && r.Gateway != "" {
|
||||
cfg.Gateway = r.Gateway
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Grab DNS from /etc/resolv.conf
|
||||
if data, err := os.ReadFile(resolvConf); err == nil {
|
||||
cfg.DNS = parseResolv(string(data))
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// parseNetworkdFile extracts IfaceConfig from a systemd .network file.
|
||||
func parseNetworkdFile(content string) IfaceConfig {
|
||||
cfg := IfaceConfig{Method: "dhcp"}
|
||||
section := ""
|
||||
|
||||
for line := range strings.SplitSeq(content, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
|
||||
section = line
|
||||
continue
|
||||
}
|
||||
|
||||
key, val, ok := strings.Cut(line, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
val = strings.TrimSpace(val)
|
||||
|
||||
if section == "[Network]" {
|
||||
switch key {
|
||||
case "DHCP":
|
||||
if val == "yes" || val == "ipv4" {
|
||||
cfg.Method = "dhcp"
|
||||
}
|
||||
case "Address":
|
||||
addr, prefix := splitCIDR(val)
|
||||
if ip, err := netip.ParseAddr(addr); err == nil && !ip.Is4() {
|
||||
g := v6(&cfg)
|
||||
g.Method = "static"
|
||||
g.Address = addr
|
||||
g.Prefix = prefix
|
||||
} else if addr != "" {
|
||||
cfg.Method = "static"
|
||||
cfg.Address = addr
|
||||
cfg.Prefix = prefix
|
||||
}
|
||||
case "Gateway":
|
||||
if ip, err := netip.ParseAddr(val); err == nil && !ip.Is4() {
|
||||
v6(&cfg).Gateway = val
|
||||
} else {
|
||||
cfg.Gateway = val
|
||||
}
|
||||
case "DNS":
|
||||
for s := range strings.FieldsSeq(val) {
|
||||
cfg.DNS = append(cfg.DNS, s)
|
||||
}
|
||||
case "IPv6AcceptRA":
|
||||
if val == "yes" {
|
||||
if g := v6(&cfg); g.Method != "static" {
|
||||
g.Method = "auto"
|
||||
}
|
||||
}
|
||||
case "LinkLocalAddressing":
|
||||
if val == "no" {
|
||||
v6(&cfg).Method = "ignore"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse [Route] sections separately.
|
||||
cfg.Routes = parseNetworkdRoutes(content)
|
||||
return cfg
|
||||
}
|
||||
|
||||
// parseNetworkdRoutes extracts Route entries from [Route] sections.
|
||||
func parseNetworkdRoutes(content string) []Route {
|
||||
var routes []Route
|
||||
var inRoute bool
|
||||
var current Route
|
||||
|
||||
for line := range strings.SplitSeq(content, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if strings.HasPrefix(line, "[") {
|
||||
// Flush any pending route when entering a new section.
|
||||
if inRoute && (current.Destination != "" || current.Gateway != "") {
|
||||
routes = append(routes, current)
|
||||
}
|
||||
inRoute = line == "[Route]"
|
||||
current = Route{}
|
||||
continue
|
||||
}
|
||||
|
||||
if !inRoute {
|
||||
continue
|
||||
}
|
||||
|
||||
key, val, ok := strings.Cut(line, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(key) {
|
||||
case "Destination":
|
||||
current.Destination = strings.TrimSpace(val)
|
||||
case "Gateway":
|
||||
current.Gateway = strings.TrimSpace(val)
|
||||
}
|
||||
}
|
||||
// Flush last route if we ended inside a [Route] section.
|
||||
if inRoute && (current.Destination != "" || current.Gateway != "") {
|
||||
routes = append(routes, current)
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
func (b *networkdBackend) Apply(ctx context.Context, iface string, cfg IfaceConfig) error {
|
||||
content := renderNetworkdFile(iface, cfg)
|
||||
path := networkdFile(iface)
|
||||
|
||||
if err := os.MkdirAll(networkdDir, 0755); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", networkdDir, err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
return fmt.Errorf("write %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Reload networkd and reconfigure the specific interface.
|
||||
if _, err := oscmd.RunContext(ctx, "networkctl", "reload"); err != nil {
|
||||
return fmt.Errorf("networkctl reload: %w", err)
|
||||
}
|
||||
if _, err := oscmd.RunContext(ctx, "networkctl", "reconfigure", "--", iface); err != nil {
|
||||
return fmt.Errorf("networkctl reconfigure %s: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// renderNetworkdFile builds the INI content for a systemd .network file.
|
||||
func renderNetworkdFile(iface string, cfg IfaceConfig) string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("# Managed by nadir — do not edit manually.\n")
|
||||
b.WriteString("[Match]\n")
|
||||
fmt.Fprintf(&b, "Name=%s\n\n", iface)
|
||||
|
||||
b.WriteString("[Network]\n")
|
||||
switch cfg.Method {
|
||||
case "dhcp":
|
||||
b.WriteString("DHCP=yes\n")
|
||||
case "static":
|
||||
b.WriteString("DHCP=no\n")
|
||||
fmt.Fprintf(&b, "Address=%s/%d\n", cfg.Address, cfg.Prefix)
|
||||
if cfg.Gateway != "" {
|
||||
fmt.Fprintf(&b, "Gateway=%s\n", cfg.Gateway)
|
||||
}
|
||||
}
|
||||
if cfg.IPv6 != nil {
|
||||
switch cfg.IPv6.Method {
|
||||
case "static":
|
||||
fmt.Fprintf(&b, "Address=%s/%d\n", cfg.IPv6.Address, cfg.IPv6.Prefix)
|
||||
if cfg.IPv6.Gateway != "" {
|
||||
fmt.Fprintf(&b, "Gateway=%s\n", cfg.IPv6.Gateway)
|
||||
}
|
||||
b.WriteString("IPv6AcceptRA=no\n")
|
||||
case "auto":
|
||||
b.WriteString("IPv6AcceptRA=yes\n")
|
||||
case "ignore":
|
||||
b.WriteString("LinkLocalAddressing=no\nIPv6AcceptRA=no\n")
|
||||
}
|
||||
}
|
||||
|
||||
for _, dns := range cfg.DNS {
|
||||
fmt.Fprintf(&b, "DNS=%s\n", dns)
|
||||
}
|
||||
|
||||
for _, r := range cfg.Routes {
|
||||
b.WriteString("\n[Route]\n")
|
||||
fmt.Fprintf(&b, "Destination=%s\n", r.Destination)
|
||||
fmt.Fprintf(&b, "Gateway=%s\n", r.Gateway)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (b *networkdBackend) SetLinkUp(ctx context.Context, iface string) error {
|
||||
// Note: ip link set parses DEVICE positionally, so -- is technically ignored
|
||||
// by ip but included here for consistency with other oscmd calls.
|
||||
_, err := oscmd.RunContext(ctx, "ip", "link", "set", "--", iface, "up")
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *networkdBackend) SetLinkDown(ctx context.Context, iface string) error {
|
||||
_, err := oscmd.RunContext(ctx, "ip", "link", "set", "--", iface, "down")
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,397 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
type mockBackend struct {
|
||||
name string
|
||||
snapshotResult IfaceConfig
|
||||
applyCalledWith IfaceConfig
|
||||
applyErr error
|
||||
snapshotErr error
|
||||
setUpCalled bool
|
||||
setDownCalled bool
|
||||
}
|
||||
|
||||
func (m *mockBackend) Name() string { return m.name }
|
||||
func (m *mockBackend) Snapshot(_ context.Context, iface string) (IfaceConfig, error) {
|
||||
return m.snapshotResult, m.snapshotErr
|
||||
}
|
||||
func (m *mockBackend) Apply(_ context.Context, iface string, cfg IfaceConfig) error {
|
||||
m.applyCalledWith = cfg
|
||||
return m.applyErr
|
||||
}
|
||||
func (m *mockBackend) SetLinkUp(_ context.Context, iface string) error {
|
||||
m.setUpCalled = true
|
||||
return nil
|
||||
}
|
||||
func (m *mockBackend) SetLinkDown(_ context.Context, iface string) error {
|
||||
m.setDownCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNetworkingHandlers(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
|
||||
be := &mockBackend{
|
||||
name: "mockbe",
|
||||
snapshotResult: IfaceConfig{
|
||||
Method: "dhcp",
|
||||
Address: "192.168.1.10/24",
|
||||
},
|
||||
}
|
||||
m := &Module{be: be}
|
||||
m.Register(api)
|
||||
|
||||
tempResolv := filepath.Join(t.TempDir(), "resolv.conf")
|
||||
if err := os.WriteFile(tempResolv, []byte("nameserver 1.1.1.1\nnameserver 8.8.8.8\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldResolv := resolvConf
|
||||
resolvConf = tempResolv
|
||||
defer func() { resolvConf = oldResolv }()
|
||||
|
||||
oscmd.SetMock("ip", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"-j", "addr"}) {
|
||||
out := `[{"ifname": "eth0", "operstate": "UP", "address": "aa:bb:cc:dd:ee:ff", "mtu": 1500, "addr_info": [{"family": "inet", "local": "192.168.1.10", "prefixlen": 24}]}]`
|
||||
return oscmd.MockCommand{Stdout: out, ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"-j", "route"}) {
|
||||
out := `[{"dst": "default", "gateway": "192.168.1.1", "dev": "eth0"}]`
|
||||
return oscmd.MockCommand{Stdout: out, ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
defer oscmd.ClearMocks()
|
||||
|
||||
// 1. Test GET /api/networking/interfaces
|
||||
resp := api.Get("/api/networking/interfaces")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list interfaces: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 2. Test GET /api/networking/routes
|
||||
resp = api.Get("/api/networking/routes")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list routes: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 3. Test GET /api/networking/dns
|
||||
resp = api.Get("/api/networking/dns")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get dns: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var dnsRes DNSOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &dnsRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(dnsRes.Body.Servers) != 2 || dnsRes.Body.Servers[0] != "1.1.1.1" {
|
||||
t.Errorf("get dns output: %+v", dnsRes.Body)
|
||||
}
|
||||
|
||||
// 4. Test PUT /api/networking/interfaces/{name}
|
||||
applyPayload := struct {
|
||||
Method string `json:"method"`
|
||||
Address string `json:"address,omitempty"`
|
||||
Prefix int `json:"prefix,omitempty"`
|
||||
Gateway string `json:"gateway,omitempty"`
|
||||
DNS []string `json:"dns,omitempty"`
|
||||
RollbackSeconds int `json:"rollback_seconds,omitempty"`
|
||||
}{
|
||||
Method: "static",
|
||||
Address: "192.168.1.20",
|
||||
Prefix: 24,
|
||||
Gateway: "192.168.1.1",
|
||||
DNS: []string{"1.1.1.1"},
|
||||
RollbackSeconds: 2,
|
||||
}
|
||||
|
||||
resp = api.Put("/api/networking/interfaces/eth0", applyPayload)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("apply interface config: got %d, want %d, body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
|
||||
resp = api.Get("/api/networking/pending")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get pending: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 5. Test POST /api/networking/interfaces/{name}/confirm
|
||||
resp = api.Post("/api/networking/interfaces/eth0/confirm", struct{}{})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("confirm change: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
resp = api.Get("/api/networking/pending")
|
||||
if resp.Code != http.StatusNotFound {
|
||||
t.Errorf("pending change should be cleared: got %d, want %d", resp.Code, http.StatusNotFound)
|
||||
}
|
||||
|
||||
// 6. Test automatic rollback
|
||||
applyPayload.RollbackSeconds = 1
|
||||
resp = api.Put("/api/networking/interfaces/eth0", applyPayload)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("apply config again: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
time.Sleep(1200 * time.Millisecond)
|
||||
|
||||
resp = api.Get("/api/networking/pending")
|
||||
if resp.Code != http.StatusNotFound {
|
||||
t.Errorf("pending change should be rolled back: got %d, want %d", resp.Code, http.StatusNotFound)
|
||||
}
|
||||
|
||||
if be.applyCalledWith.Method != "dhcp" || be.applyCalledWith.Address != "192.168.1.10/24" {
|
||||
t.Errorf("rollback failed to restore prior config: %+v", be.applyCalledWith)
|
||||
}
|
||||
|
||||
// 7. Test POST link up & down
|
||||
resp = api.Post("/api/networking/interfaces/eth0/up", struct{}{})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set link up: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
if !be.setUpCalled {
|
||||
t.Errorf("expected SetLinkUp call on backend")
|
||||
}
|
||||
|
||||
resp = api.Post("/api/networking/interfaces/eth0/down", struct{}{})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set link down: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
if !be.setDownCalled {
|
||||
t.Errorf("expected SetLinkDown call on backend")
|
||||
}
|
||||
}
|
||||
|
||||
// #3: a failed Apply must restore the prior snapshot (no half-applied config
|
||||
// left with no auto-revert) and arm no pending change.
|
||||
func TestApplyFailureRestoresPrior(t *testing.T) {
|
||||
be := &mockBackend{
|
||||
name: "mock",
|
||||
snapshotResult: IfaceConfig{Method: "dhcp"},
|
||||
applyErr: errors.New("con up failed"),
|
||||
}
|
||||
m := &Module{be: be}
|
||||
|
||||
_, err := m.startRollback(t.Context(), "eth0", IfaceConfig{Method: "static", Address: "10.0.0.5", Prefix: 24})
|
||||
if err == nil {
|
||||
t.Fatal("expected apply to fail")
|
||||
}
|
||||
// The last Apply call should be the restore-to-prior (dhcp), not the new config.
|
||||
if be.applyCalledWith.Method != "dhcp" {
|
||||
t.Errorf("expected restore to prior config, last Apply got %+v", be.applyCalledWith)
|
||||
}
|
||||
if m.pending != nil {
|
||||
t.Error("no pending change should be armed after a failed apply")
|
||||
}
|
||||
}
|
||||
|
||||
// #1: link-down must go through the rollback safety net (arms a pending change),
|
||||
// and rolling it back brings the interface back up. #2: a change to another
|
||||
// interface while one is pending is rejected with errAlreadyPending.
|
||||
func TestLinkDownArmsRollback(t *testing.T) {
|
||||
be := &mockBackend{name: "mock"}
|
||||
m := &Module{be: be}
|
||||
|
||||
secs, err := m.startLinkDown(t.Context(), "eth0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secs != defaultRollbackSeconds {
|
||||
t.Errorf("seconds = %d, want default %d", secs, defaultRollbackSeconds)
|
||||
}
|
||||
if !be.setDownCalled {
|
||||
t.Error("expected SetLinkDown to be called")
|
||||
}
|
||||
if m.pending == nil {
|
||||
t.Fatal("expected a pending change to be armed")
|
||||
}
|
||||
|
||||
// A concurrent change to a different interface is rejected (global lock).
|
||||
if _, err := m.startRollback(t.Context(), "eth1", IfaceConfig{Method: "dhcp"}); !errors.Is(err, errAlreadyPending) {
|
||||
t.Errorf("expected errAlreadyPending for eth1, got %v", err)
|
||||
}
|
||||
|
||||
// Rolling back brings the link back up and clears the pending change.
|
||||
if err := m.rollbackNow("eth0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !be.setUpCalled {
|
||||
t.Error("expected SetLinkUp on rollback")
|
||||
}
|
||||
if m.pending != nil {
|
||||
t.Error("pending change should be cleared after rollback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendImplementations(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Mock ip command for fallback snapshots and link control
|
||||
oscmd.SetMock("ip", func(args []string) oscmd.MockCommand {
|
||||
argStr := strings.Join(args, " ")
|
||||
if strings.Contains(argStr, "addr show") {
|
||||
out := `[{"ifname": "eth0", "operstate": "UP", "address": "aa:bb:cc:dd:ee:ff", "mtu": 1500, "addr_info": [{"family": "inet", "local": "192.168.1.30", "prefixlen": 24}]}]`
|
||||
return oscmd.MockCommand{Stdout: out, ExitCode: 0}
|
||||
}
|
||||
if strings.Contains(argStr, "route show") {
|
||||
out := `[{"dst": "default", "gateway": "192.168.1.1", "dev": "eth0"}]`
|
||||
return oscmd.MockCommand{Stdout: out, ExitCode: 0}
|
||||
}
|
||||
if strings.Contains(argStr, "link set") {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
defer oscmd.ClearMocks()
|
||||
|
||||
// 1. Test nmcliBackend
|
||||
nm := &nmcliBackend{}
|
||||
oscmd.SetMock("nmcli", func(args []string) oscmd.MockCommand {
|
||||
argStr := strings.Join(args, " ")
|
||||
t.Logf("nmcli mock called with: %s", argStr)
|
||||
if strings.Contains(argStr, "con show --active") {
|
||||
return oscmd.MockCommand{Stdout: "myconn:eth0\n", ExitCode: 0}
|
||||
}
|
||||
if strings.Contains(argStr, "con show") && strings.Contains(argStr, "myconn") {
|
||||
showOut := "ipv4.method:manual\nipv4.addresses:192.168.1.10/24\nipv4.gateway:192.168.1.1\nipv4.dns:1.1.1.1\nipv4.routes:10.0.0.0/24 192.168.1.1\n"
|
||||
return oscmd.MockCommand{Stdout: showOut, ExitCode: 0}
|
||||
}
|
||||
if (strings.Contains(argStr, "con modify") || strings.Contains(argStr, "con up") || strings.Contains(argStr, "con down")) && strings.Contains(argStr, "myconn") {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
|
||||
cfg, err := nm.Snapshot(t.Context(), "eth0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Address != "192.168.1.10" || cfg.Gateway != "192.168.1.1" {
|
||||
t.Errorf("nmcli snapshot failed: %+v", cfg)
|
||||
}
|
||||
|
||||
err = nm.Apply(t.Context(), "eth0", IfaceConfig{
|
||||
Method: "static",
|
||||
Address: "192.168.1.20",
|
||||
Prefix: 24,
|
||||
Gateway: "192.168.1.1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := nm.SetLinkUp(t.Context(), "eth0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := nm.SetLinkDown(t.Context(), "eth0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// 2. Test networkdBackend
|
||||
nd := &networkdBackend{}
|
||||
oldNdDir := networkdDir
|
||||
networkdDir = tempDir
|
||||
defer func() { networkdDir = oldNdDir }()
|
||||
|
||||
oscmd.SetMock("networkctl", func(args []string) oscmd.MockCommand {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
})
|
||||
|
||||
err = nd.Apply(t.Context(), "eth0", IfaceConfig{
|
||||
Method: "static",
|
||||
Address: "192.168.1.30",
|
||||
Prefix: 24,
|
||||
Gateway: "192.168.1.1",
|
||||
DNS: []string{"8.8.8.8"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err = nd.Snapshot(t.Context(), "eth0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Address != "192.168.1.30" || cfg.Method != "static" {
|
||||
t.Errorf("networkd snapshot failed: %+v", cfg)
|
||||
}
|
||||
|
||||
if err := nd.SetLinkUp(t.Context(), "eth0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := nd.SetLinkDown(t.Context(), "eth0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// 3. Test ifupdownBackend
|
||||
iu := &ifupdownBackend{}
|
||||
oldIfaceFile := interfacesFile
|
||||
oldInterfacesDDir := interfacesDDir
|
||||
interfacesFile = filepath.Join(tempDir, "interfaces")
|
||||
interfacesDDir = filepath.Join(tempDir, "interfaces.d")
|
||||
defer func() {
|
||||
interfacesFile = oldIfaceFile
|
||||
interfacesDDir = oldInterfacesDDir
|
||||
}()
|
||||
|
||||
if err := os.MkdirAll(interfacesDDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(interfacesFile, []byte("auto lo\niface lo inet loopback\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
oscmd.SetMock("ifup", func(args []string) oscmd.MockCommand { return oscmd.MockCommand{ExitCode: 0} })
|
||||
oscmd.SetMock("ifdown", func(args []string) oscmd.MockCommand { return oscmd.MockCommand{ExitCode: 0} })
|
||||
|
||||
err = iu.Apply(t.Context(), "eth0", IfaceConfig{
|
||||
Method: "static",
|
||||
Address: "192.168.1.40",
|
||||
Prefix: 24,
|
||||
Gateway: "192.168.1.1",
|
||||
DNS: []string{"1.1.1.1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err = iu.Snapshot(t.Context(), "eth0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Address != "192.168.1.40" || cfg.Method != "static" {
|
||||
t.Errorf("ifupdown snapshot failed: %+v", cfg)
|
||||
}
|
||||
|
||||
if err := iu.SetLinkUp(t.Context(), "eth0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := iu.SetLinkDown(t.Context(), "eth0"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseInterfaces(t *testing.T) {
|
||||
// Trimmed real output from `ip -j addr` on a typical host.
|
||||
input := `[
|
||||
{
|
||||
"ifname": "lo",
|
||||
"address": "00:00:00:00:00:00",
|
||||
"mtu": 65536,
|
||||
"operstate": "UNKNOWN",
|
||||
"addr_info": [
|
||||
{"family": "inet", "local": "127.0.0.1", "prefixlen": 8},
|
||||
{"family": "inet6", "local": "::1", "prefixlen": 128}
|
||||
]
|
||||
},
|
||||
{
|
||||
"ifname": "eth0",
|
||||
"address": "52:54:00:12:34:56",
|
||||
"mtu": 1500,
|
||||
"operstate": "UP",
|
||||
"addr_info": [
|
||||
{"family": "inet", "local": "192.168.1.10", "prefixlen": 24},
|
||||
{"family": "inet6", "local": "fe80::1", "prefixlen": 64}
|
||||
]
|
||||
}
|
||||
]`
|
||||
|
||||
ifaces, err := parseInterfaces(input)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(ifaces) != 2 {
|
||||
t.Fatalf("expected 2 interfaces, got %d", len(ifaces))
|
||||
}
|
||||
|
||||
lo := ifaces[0]
|
||||
if lo.Name != "lo" || lo.State != "unknown" || lo.MTU != 65536 {
|
||||
t.Errorf("lo: got %+v", lo)
|
||||
}
|
||||
if len(lo.IPv4) != 1 || lo.IPv4[0] != "127.0.0.1/8" {
|
||||
t.Errorf("lo ipv4: got %v", lo.IPv4)
|
||||
}
|
||||
|
||||
eth0 := ifaces[1]
|
||||
if eth0.Name != "eth0" || eth0.State != "up" || eth0.MAC != "52:54:00:12:34:56" {
|
||||
t.Errorf("eth0: got %+v", eth0)
|
||||
}
|
||||
if len(eth0.IPv4) != 1 || eth0.IPv4[0] != "192.168.1.10/24" {
|
||||
t.Errorf("eth0 ipv4: got %v", eth0.IPv4)
|
||||
}
|
||||
if len(eth0.IPv6) != 1 || eth0.IPv6[0] != "fe80::1/64" {
|
||||
t.Errorf("eth0 ipv6: got %v", eth0.IPv6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRoutes(t *testing.T) {
|
||||
input := `[
|
||||
{"dst": "default", "gateway": "192.168.1.1", "dev": "eth0", "prefsrc": "192.168.1.10", "metric": 100},
|
||||
{"dst": "192.168.1.0/24", "dev": "eth0", "prefsrc": "192.168.1.10"}
|
||||
]`
|
||||
|
||||
routes, err := parseRoutes(input)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(routes) != 2 {
|
||||
t.Fatalf("expected 2 routes, got %d", len(routes))
|
||||
}
|
||||
if routes[0].Destination != "default" || routes[0].Gateway != "192.168.1.1" || routes[0].Metric != 100 {
|
||||
t.Errorf("route 0: got %+v", routes[0])
|
||||
}
|
||||
if routes[1].Destination != "192.168.1.0/24" || routes[1].Interface != "eth0" {
|
||||
t.Errorf("route 1: got %+v", routes[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolv(t *testing.T) {
|
||||
input := `# Generated by NetworkManager
|
||||
nameserver 1.1.1.1
|
||||
nameserver 8.8.8.8
|
||||
; legacy comment
|
||||
search example.com
|
||||
nameserver 9.9.9.9
|
||||
`
|
||||
servers := parseResolv(input)
|
||||
want := []string{"1.1.1.1", "8.8.8.8", "9.9.9.9"}
|
||||
if !reflect.DeepEqual(servers, want) {
|
||||
t.Errorf("parseResolv() = %v, want %v", servers, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolvEmpty(t *testing.T) {
|
||||
servers := parseResolv("")
|
||||
if len(servers) != 0 {
|
||||
t.Errorf("expected empty, got %v", servers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNmcliSnapshot(t *testing.T) {
|
||||
input := `ipv4.method:manual
|
||||
ipv4.addresses:192.168.1.10/24
|
||||
ipv4.gateway:192.168.1.1
|
||||
ipv4.dns:1.1.1.1,8.8.8.8
|
||||
ipv4.routes:--`
|
||||
|
||||
cfg := parseNmcliSnapshot(input)
|
||||
if cfg.Method != "static" {
|
||||
t.Errorf("method: got %q, want static", cfg.Method)
|
||||
}
|
||||
if cfg.Address != "192.168.1.10" || cfg.Prefix != 24 {
|
||||
t.Errorf("address: got %s/%d", cfg.Address, cfg.Prefix)
|
||||
}
|
||||
if cfg.Gateway != "192.168.1.1" {
|
||||
t.Errorf("gateway: got %q", cfg.Gateway)
|
||||
}
|
||||
if len(cfg.DNS) != 2 || cfg.DNS[0] != "1.1.1.1" || cfg.DNS[1] != "8.8.8.8" {
|
||||
t.Errorf("dns: got %v", cfg.DNS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNmcliRoutes(t *testing.T) {
|
||||
input := `dst=10.0.0.0/24, nh=192.168.1.1; dst=172.16.0.0/12, nh=10.0.0.1`
|
||||
routes := parseNmcliRoutes(input)
|
||||
if len(routes) != 2 {
|
||||
t.Fatalf("expected 2 routes, got %d", len(routes))
|
||||
}
|
||||
if routes[0].Destination != "10.0.0.0/24" || routes[0].Gateway != "192.168.1.1" {
|
||||
t.Errorf("route 0: got %v", routes[0])
|
||||
}
|
||||
if routes[1].Destination != "172.16.0.0/12" || routes[1].Gateway != "10.0.0.1" {
|
||||
t.Errorf("route 1: got %v", routes[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNmcliSnapshotDHCP(t *testing.T) {
|
||||
input := `ipv4.method:auto
|
||||
ipv4.addresses:--
|
||||
ipv4.gateway:--
|
||||
ipv4.dns:--
|
||||
ipv4.routes:--`
|
||||
|
||||
cfg := parseNmcliSnapshot(input)
|
||||
if cfg.Method != "dhcp" {
|
||||
t.Errorf("method: got %q, want dhcp", cfg.Method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNetworkdFile(t *testing.T) {
|
||||
input := `# Managed by nadir
|
||||
[Match]
|
||||
Name=eth0
|
||||
|
||||
[Network]
|
||||
DHCP=no
|
||||
Address=10.0.0.5/24
|
||||
Gateway=10.0.0.1
|
||||
DNS=1.1.1.1
|
||||
DNS=8.8.8.8
|
||||
|
||||
[Route]
|
||||
Destination=192.168.0.0/16
|
||||
Gateway=10.0.0.254
|
||||
`
|
||||
|
||||
cfg := parseNetworkdFile(input)
|
||||
if cfg.Method != "static" {
|
||||
t.Errorf("method: got %q, want static", cfg.Method)
|
||||
}
|
||||
if cfg.Address != "10.0.0.5" || cfg.Prefix != 24 {
|
||||
t.Errorf("address: got %s/%d", cfg.Address, cfg.Prefix)
|
||||
}
|
||||
if cfg.Gateway != "10.0.0.1" {
|
||||
t.Errorf("gateway: got %q", cfg.Gateway)
|
||||
}
|
||||
if len(cfg.DNS) != 2 {
|
||||
t.Errorf("dns: got %v", cfg.DNS)
|
||||
}
|
||||
if len(cfg.Routes) != 1 || cfg.Routes[0].Destination != "192.168.0.0/16" || cfg.Routes[0].Gateway != "10.0.0.254" {
|
||||
t.Errorf("routes: got %v", cfg.Routes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNetworkdFileDHCP(t *testing.T) {
|
||||
input := `[Match]
|
||||
Name=eth0
|
||||
|
||||
[Network]
|
||||
DHCP=yes
|
||||
`
|
||||
cfg := parseNetworkdFile(input)
|
||||
if cfg.Method != "dhcp" {
|
||||
t.Errorf("method: got %q, want dhcp", cfg.Method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIfupdownStanza(t *testing.T) {
|
||||
input := `# Managed by nadir
|
||||
auto eth0
|
||||
iface eth0 inet static
|
||||
address 192.168.1.10/24
|
||||
gateway 192.168.1.1
|
||||
dns-nameservers 1.1.1.1 8.8.8.8
|
||||
up ip route add 10.0.0.0/24 via 192.168.1.254
|
||||
down ip route del 10.0.0.0/24 via 192.168.1.254
|
||||
`
|
||||
|
||||
cfg := parseIfupdownStanza(input)
|
||||
if cfg.Method != "static" {
|
||||
t.Errorf("method: got %q, want static", cfg.Method)
|
||||
}
|
||||
if cfg.Address != "192.168.1.10" || cfg.Prefix != 24 {
|
||||
t.Errorf("address: got %s/%d", cfg.Address, cfg.Prefix)
|
||||
}
|
||||
if cfg.Gateway != "192.168.1.1" {
|
||||
t.Errorf("gateway: got %q", cfg.Gateway)
|
||||
}
|
||||
if len(cfg.DNS) != 2 || cfg.DNS[0] != "1.1.1.1" || cfg.DNS[1] != "8.8.8.8" {
|
||||
t.Errorf("dns: got %v", cfg.DNS)
|
||||
}
|
||||
if len(cfg.Routes) != 1 || cfg.Routes[0].Destination != "10.0.0.0/24" || cfg.Routes[0].Gateway != "192.168.1.254" {
|
||||
t.Errorf("routes: got %v", cfg.Routes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIfupdownStanzaDHCP(t *testing.T) {
|
||||
input := `auto eth0
|
||||
iface eth0 inet dhcp
|
||||
`
|
||||
cfg := parseIfupdownStanza(input)
|
||||
if cfg.Method != "dhcp" {
|
||||
t.Errorf("method: got %q, want dhcp", cfg.Method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderNetworkdFile(t *testing.T) {
|
||||
cfg := IfaceConfig{
|
||||
Method: "static",
|
||||
Address: "10.0.0.5",
|
||||
Prefix: 24,
|
||||
Gateway: "10.0.0.1",
|
||||
DNS: []string{"1.1.1.1"},
|
||||
Routes: []Route{{Destination: "192.168.0.0/16", Gateway: "10.0.0.254"}},
|
||||
}
|
||||
out := renderNetworkdFile("eth0", cfg)
|
||||
|
||||
mustContain := []string{
|
||||
"Name=eth0",
|
||||
"DHCP=no",
|
||||
"Address=10.0.0.5/24",
|
||||
"Gateway=10.0.0.1",
|
||||
"DNS=1.1.1.1",
|
||||
"Destination=192.168.0.0/16",
|
||||
}
|
||||
for _, s := range mustContain {
|
||||
if !strings.Contains(out, s) {
|
||||
t.Errorf("renderNetworkdFile missing %q in:\n%s", s, out)
|
||||
}
|
||||
}
|
||||
|
||||
// Roundtrip test
|
||||
parsed := parseNetworkdFile(out)
|
||||
if !reflect.DeepEqual(parsed, cfg) {
|
||||
t.Errorf("roundtrip failed: got %+v, want %+v", parsed, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderIfupdownStanza(t *testing.T) {
|
||||
cfg := IfaceConfig{
|
||||
Method: "static",
|
||||
Address: "192.168.1.10",
|
||||
Prefix: 24,
|
||||
Gateway: "192.168.1.1",
|
||||
DNS: []string{"1.1.1.1", "8.8.8.8"},
|
||||
Routes: []Route{{Destination: "10.0.0.0/24", Gateway: "192.168.1.254"}},
|
||||
}
|
||||
out := renderIfupdownStanza("eth0", cfg)
|
||||
|
||||
mustContain := []string{
|
||||
"iface eth0 inet static",
|
||||
"address 192.168.1.10/24",
|
||||
"gateway 192.168.1.1",
|
||||
"dns-nameservers 1.1.1.1 8.8.8.8",
|
||||
"up ip route add 10.0.0.0/24 via 192.168.1.254",
|
||||
"down ip route del 10.0.0.0/24 via 192.168.1.254",
|
||||
}
|
||||
for _, s := range mustContain {
|
||||
if !strings.Contains(out, s) {
|
||||
t.Errorf("renderIfupdownStanza missing %q in:\n%s", s, out)
|
||||
}
|
||||
}
|
||||
|
||||
// Roundtrip test
|
||||
parsed := parseIfupdownStanza(out)
|
||||
if !reflect.DeepEqual(parsed, cfg) {
|
||||
t.Errorf("roundtrip failed: got %+v, want %+v", parsed, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIPv6(t *testing.T) {
|
||||
ok := IfaceConfig{Method: "dhcp", IPv6: &IPv6Config{Method: "static", Address: "2001:db8::1", Prefix: 64, Gateway: "2001:db8::ff"}}
|
||||
if err := ok.validate(); err != nil {
|
||||
t.Errorf("valid ipv6 rejected: %v", err)
|
||||
}
|
||||
bad := []IfaceConfig{
|
||||
{Method: "dhcp", IPv6: &IPv6Config{Method: "static", Address: "1.2.3.4", Prefix: 64}}, // v4 in v6 block
|
||||
{Method: "dhcp", IPv6: &IPv6Config{Method: "static", Address: "2001:db8::1", Prefix: 0}},
|
||||
{Method: "dhcp", IPv6: &IPv6Config{Method: "bogus"}},
|
||||
}
|
||||
for i, c := range bad {
|
||||
if err := c.validate(); err == nil {
|
||||
t.Errorf("bad ipv6 case %d accepted", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNmcliSnapshotIPv6(t *testing.T) {
|
||||
input := `ipv4.method:manual
|
||||
ipv4.addresses:192.168.1.10/24
|
||||
ipv6.method:manual
|
||||
ipv6.addresses:2001:db8::10/64
|
||||
ipv6.gateway:2001:db8::1`
|
||||
cfg := parseNmcliSnapshot(input)
|
||||
if cfg.IPv6 == nil {
|
||||
t.Fatal("ipv6 not captured")
|
||||
}
|
||||
if cfg.IPv6.Method != "static" || cfg.IPv6.Address != "2001:db8::10" || cfg.IPv6.Prefix != 64 || cfg.IPv6.Gateway != "2001:db8::1" {
|
||||
t.Errorf("ipv6 snapshot: %+v", cfg.IPv6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkdIPv6RoundTrip(t *testing.T) {
|
||||
cfg := IfaceConfig{
|
||||
Method: "static",
|
||||
Address: "10.0.0.5",
|
||||
Prefix: 24,
|
||||
IPv6: &IPv6Config{Method: "static", Address: "2001:db8::5", Prefix: 64, Gateway: "2001:db8::1"},
|
||||
}
|
||||
got := parseNetworkdFile(renderNetworkdFile("eth0", cfg))
|
||||
if got.Address != "10.0.0.5" || got.Prefix != 24 {
|
||||
t.Errorf("ipv4 lost in roundtrip: %+v", got)
|
||||
}
|
||||
if !reflect.DeepEqual(got.IPv6, cfg.IPv6) {
|
||||
t.Errorf("ipv6 roundtrip: got %+v want %+v", got.IPv6, cfg.IPv6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIfupdownIPv6RoundTrip(t *testing.T) {
|
||||
cfg := IfaceConfig{
|
||||
Method: "static",
|
||||
Address: "192.168.1.10",
|
||||
Prefix: 24,
|
||||
IPv6: &IPv6Config{Method: "static", Address: "2001:db8::10", Prefix: 64, Gateway: "2001:db8::1"},
|
||||
}
|
||||
got := parseIfupdownStanza(renderIfupdownStanza("eth0", cfg))
|
||||
if got.Address != "192.168.1.10" || got.Prefix != 24 || got.Method != "static" {
|
||||
t.Errorf("ipv4 lost in roundtrip: %+v", got)
|
||||
}
|
||||
if !reflect.DeepEqual(got.IPv6, cfg.IPv6) {
|
||||
t.Errorf("ipv6 roundtrip: got %+v want %+v", got.IPv6, cfg.IPv6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkdIPv6AutoIgnore(t *testing.T) {
|
||||
auto := parseNetworkdFile(renderNetworkdFile("eth0", IfaceConfig{Method: "dhcp", IPv6: &IPv6Config{Method: "auto"}}))
|
||||
if auto.IPv6 == nil || auto.IPv6.Method != "auto" {
|
||||
t.Errorf("auto roundtrip: %+v", auto.IPv6)
|
||||
}
|
||||
ign := parseNetworkdFile(renderNetworkdFile("eth0", IfaceConfig{Method: "dhcp", IPv6: &IPv6Config{Method: "ignore"}}))
|
||||
if ign.IPv6 == nil || ign.IPv6.Method != "ignore" {
|
||||
t.Errorf("ignore roundtrip: %+v", ign.IPv6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitCIDR(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantAddr string
|
||||
wantPrefix int
|
||||
}{
|
||||
{"192.168.1.10/24", "192.168.1.10", 24},
|
||||
{"10.0.0.1/8", "10.0.0.1", 8},
|
||||
{"10.0.0.1", "10.0.0.1", 0},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
addr, prefix := splitCIDR(tt.input)
|
||||
if addr != tt.wantAddr || prefix != tt.wantPrefix {
|
||||
t.Errorf("splitCIDR(%q) = (%q, %d), want (%q, %d)", tt.input, addr, prefix, tt.wantAddr, tt.wantPrefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
)
|
||||
|
||||
// nmcliBackend implements backend via NetworkManager's nmcli CLI.
|
||||
type nmcliBackend struct{}
|
||||
|
||||
func (b *nmcliBackend) Name() string { return "nmcli" }
|
||||
|
||||
// connForIface resolves a network interface name to the NM connection name
|
||||
// that owns it. Returns an error if the interface has no active connection.
|
||||
//
|
||||
// nmcli -t uses ':' as the field separator. Connection names can contain colons
|
||||
// (e.g. "VLAN:100"), but Linux device names cannot, so we split on the last
|
||||
// colon to get the device and treat everything before it as the connection name.
|
||||
func connForIface(ctx context.Context, iface string) (string, error) {
|
||||
out, err := oscmd.RunContext(ctx, "nmcli", "-t", "-f", "NAME,DEVICE", "con", "show", "--active")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("nmcli con show: %w", err)
|
||||
}
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
// Split on the last colon: "conn:name:eth0" → ("conn:name", "eth0")
|
||||
idx := strings.LastIndex(line, ":")
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
name, dev := line[:idx], line[idx+1:]
|
||||
if dev == iface {
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no active NM connection found for interface %s", iface)
|
||||
}
|
||||
|
||||
func (b *nmcliBackend) Snapshot(ctx context.Context, iface string) (IfaceConfig, error) {
|
||||
conn, err := connForIface(ctx, iface)
|
||||
if err != nil {
|
||||
// No managed connection → assume DHCP (safest rollback assumption).
|
||||
return IfaceConfig{Method: "dhcp"}, nil
|
||||
}
|
||||
|
||||
out, err := oscmd.RunContext(ctx, "nmcli", "-t", "-f",
|
||||
"ipv4.method,ipv4.addresses,ipv4.gateway,ipv4.dns,ipv4.routes,ipv6.method,ipv6.addresses,ipv6.gateway",
|
||||
"con", "show", "--", conn)
|
||||
if err != nil {
|
||||
return IfaceConfig{}, fmt.Errorf("nmcli con show %s: %w", conn, err)
|
||||
}
|
||||
|
||||
return parseNmcliSnapshot(out), nil
|
||||
}
|
||||
|
||||
// parseNmcliSnapshot parses the terse output of `nmcli -t -f ... con show`.
|
||||
// Fields are colon-separated key:value lines. Multi-valued fields (addresses,
|
||||
// dns, routes) use comma separation within the value.
|
||||
func parseNmcliSnapshot(out string) IfaceConfig {
|
||||
cfg := IfaceConfig{Method: "dhcp"}
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
key, val, ok := strings.Cut(line, ":")
|
||||
if !ok || val == "" || val == "--" {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case "ipv4.method":
|
||||
if val == "manual" {
|
||||
cfg.Method = "static"
|
||||
} else {
|
||||
cfg.Method = "dhcp"
|
||||
}
|
||||
case "ipv4.addresses":
|
||||
// "192.168.1.10/24" or "192.168.1.10/24, 10.0.0.1/8"
|
||||
for _, part := range splitTrim(val, ",") {
|
||||
addr, prefix := splitCIDR(part)
|
||||
if addr != "" {
|
||||
cfg.Address = addr
|
||||
cfg.Prefix = prefix
|
||||
}
|
||||
}
|
||||
case "ipv4.gateway":
|
||||
cfg.Gateway = val
|
||||
case "ipv4.dns":
|
||||
cfg.DNS = append(cfg.DNS, splitTrim(val, ",")...)
|
||||
case "ipv4.routes":
|
||||
// "dst=10.0.0.0/24, nh=192.168.1.1; dst=default, nh=192.168.1.1"
|
||||
// or simpler "{ dst = 10.0.0.0/24, nh = 192.168.1.1 }"
|
||||
cfg.Routes = parseNmcliRoutes(val)
|
||||
case "ipv6.method":
|
||||
v6(&cfg).Method = nmcliV6Method(val)
|
||||
case "ipv6.addresses":
|
||||
for _, part := range splitTrim(val, ",") {
|
||||
if addr, prefix := splitCIDR(part); addr != "" {
|
||||
v6(&cfg).Address = addr
|
||||
v6(&cfg).Prefix = prefix
|
||||
}
|
||||
}
|
||||
case "ipv6.gateway":
|
||||
v6(&cfg).Gateway = val
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// v6 lazily allocates the IPv6 block so a snapshot captures the live IPv6 state
|
||||
// (defaulting to "auto") even when only some ipv6.* fields are present.
|
||||
func v6(cfg *IfaceConfig) *IPv6Config {
|
||||
if cfg.IPv6 == nil {
|
||||
cfg.IPv6 = &IPv6Config{Method: "auto"}
|
||||
}
|
||||
return cfg.IPv6
|
||||
}
|
||||
|
||||
// nmcliV6Method maps nmcli's ipv6.method to our vocabulary.
|
||||
func nmcliV6Method(val string) string {
|
||||
switch val {
|
||||
case "manual":
|
||||
return "static"
|
||||
case "ignore", "disabled", "link-local":
|
||||
return "ignore"
|
||||
default:
|
||||
return "auto"
|
||||
}
|
||||
}
|
||||
|
||||
// parseNmcliRoutes parses route entries from nmcli terse output.
|
||||
func parseNmcliRoutes(val string) []Route {
|
||||
var routes []Route
|
||||
// Routes are semicolon-separated, each containing "dst=..., nh=..."
|
||||
for _, entry := range splitTrim(val, ";") {
|
||||
entry = strings.Trim(entry, "{}")
|
||||
var r Route
|
||||
for _, part := range splitTrim(entry, ",") {
|
||||
k, v, ok := strings.Cut(part, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(k) {
|
||||
case "dst":
|
||||
r.Destination = strings.TrimSpace(v)
|
||||
case "nh":
|
||||
r.Gateway = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
if r.Destination != "" && r.Gateway != "" {
|
||||
routes = append(routes, r)
|
||||
}
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
func (b *nmcliBackend) Apply(ctx context.Context, iface string, cfg IfaceConfig) error {
|
||||
conn, err := connForIface(ctx, iface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot apply: %w", err)
|
||||
}
|
||||
|
||||
// Build the nmcli con modify arguments. Note: conn is safe to place after
|
||||
// -- since it comes from nmcli output, not directly from the user.
|
||||
args := []string{"con", "modify", "--", conn}
|
||||
|
||||
switch cfg.Method {
|
||||
case "static":
|
||||
cidr := fmt.Sprintf("%s/%d", cfg.Address, cfg.Prefix)
|
||||
args = append(args,
|
||||
"ipv4.method", "manual",
|
||||
"ipv4.addresses", cidr,
|
||||
)
|
||||
if cfg.Gateway != "" {
|
||||
args = append(args, "ipv4.gateway", cfg.Gateway)
|
||||
} else {
|
||||
args = append(args, "ipv4.gateway", "")
|
||||
}
|
||||
case "dhcp":
|
||||
args = append(args,
|
||||
"ipv4.method", "auto",
|
||||
"ipv4.addresses", "",
|
||||
"ipv4.gateway", "",
|
||||
)
|
||||
}
|
||||
|
||||
// DNS: set or clear.
|
||||
if len(cfg.DNS) > 0 {
|
||||
args = append(args, "ipv4.dns", strings.Join(cfg.DNS, ","))
|
||||
} else {
|
||||
args = append(args, "ipv4.dns", "")
|
||||
}
|
||||
|
||||
// Routes: set or clear.
|
||||
if len(cfg.Routes) > 0 {
|
||||
var routeStrs []string
|
||||
for _, r := range cfg.Routes {
|
||||
routeStrs = append(routeStrs, r.Destination+" "+r.Gateway)
|
||||
}
|
||||
args = append(args, "ipv4.routes", strings.Join(routeStrs, ","))
|
||||
} else {
|
||||
args = append(args, "ipv4.routes", "")
|
||||
}
|
||||
|
||||
// IPv6: only touched when the request includes an ipv6 block.
|
||||
if cfg.IPv6 != nil {
|
||||
switch cfg.IPv6.Method {
|
||||
case "static":
|
||||
args = append(args,
|
||||
"ipv6.method", "manual",
|
||||
"ipv6.addresses", fmt.Sprintf("%s/%d", cfg.IPv6.Address, cfg.IPv6.Prefix),
|
||||
"ipv6.gateway", cfg.IPv6.Gateway, // "" clears it
|
||||
)
|
||||
case "auto":
|
||||
args = append(args, "ipv6.method", "auto", "ipv6.addresses", "", "ipv6.gateway", "")
|
||||
case "ignore":
|
||||
args = append(args, "ipv6.method", "ignore", "ipv6.addresses", "", "ipv6.gateway", "")
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := oscmd.RunContext(ctx, "nmcli", args...); err != nil {
|
||||
return fmt.Errorf("nmcli con modify: %w", err)
|
||||
}
|
||||
|
||||
// Bring the connection up to apply changes.
|
||||
if _, err := oscmd.RunContext(ctx, "nmcli", "con", "up", "--", conn); err != nil {
|
||||
return fmt.Errorf("nmcli con up: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *nmcliBackend) SetLinkUp(ctx context.Context, iface string) error {
|
||||
conn, err := connForIface(ctx, iface)
|
||||
if err != nil {
|
||||
// No NM connection - fall back to `ip link`.
|
||||
_, err = oscmd.RunContext(ctx, "ip", "link", "set", "--", iface, "up")
|
||||
return err
|
||||
}
|
||||
_, err = oscmd.RunContext(ctx, "nmcli", "con", "up", "--", conn)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *nmcliBackend) SetLinkDown(ctx context.Context, iface string) error {
|
||||
conn, err := connForIface(ctx, iface)
|
||||
if err != nil {
|
||||
_, err = oscmd.RunContext(ctx, "ip", "link", "set", "--", iface, "down")
|
||||
return err
|
||||
}
|
||||
_, err = oscmd.RunContext(ctx, "nmcli", "con", "down", "--", conn)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- helpers -----------------------------------------------------------------
|
||||
|
||||
// splitTrim splits s on sep and returns the trimmed, non-empty segments.
|
||||
func splitTrim(s, sep string) []string {
|
||||
var out []string
|
||||
for part := range strings.SplitSeq(s, sep) {
|
||||
if p := strings.TrimSpace(part); p != "" {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// splitCIDR splits "192.168.1.10/24" into ("192.168.1.10", 24). A missing or
|
||||
// malformed prefix yields 0.
|
||||
func splitCIDR(cidr string) (string, int) {
|
||||
addr, prefixStr, ok := strings.Cut(cidr, "/")
|
||||
if !ok {
|
||||
return addr, 0
|
||||
}
|
||||
prefix, _ := strconv.Atoi(prefixStr)
|
||||
return addr, prefix
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// Reads go through `ip -j` (JSON) and /etc/resolv.conf, which behave the same
|
||||
// regardless of which backend manages the interfaces - so reads need no backend
|
||||
// detection.
|
||||
|
||||
var resolvConf = "/etc/resolv.conf"
|
||||
|
||||
type Interface struct {
|
||||
Name string `json:"name" example:"eth0"`
|
||||
State string `json:"state" example:"up" doc:"operstate: up / down / unknown"`
|
||||
MAC string `json:"mac" example:"52:54:00:12:34:56"`
|
||||
MTU int `json:"mtu" example:"1500"`
|
||||
IPv4 []string `json:"ipv4" example:"[\"192.168.1.10/24\"]" doc:"IPv4 addresses in CIDR form"`
|
||||
IPv6 []string `json:"ipv6" example:"[\"fe80::1/64\"]" doc:"IPv6 addresses in CIDR form"`
|
||||
}
|
||||
|
||||
type ListInterfacesOutput struct {
|
||||
Body struct {
|
||||
Interfaces []Interface `json:"interfaces"`
|
||||
}
|
||||
}
|
||||
|
||||
type RouteEntry struct {
|
||||
Destination string `json:"destination" example:"default" doc:"Destination network, or \"default\""`
|
||||
Gateway string `json:"gateway,omitempty" example:"192.168.1.1"`
|
||||
Interface string `json:"interface" example:"eth0"`
|
||||
Source string `json:"source,omitempty" example:"192.168.1.10" doc:"Preferred source address"`
|
||||
Metric int `json:"metric,omitempty" example:"100"`
|
||||
}
|
||||
|
||||
type ListRoutesOutput struct {
|
||||
Body struct {
|
||||
Routes []RouteEntry `json:"routes"`
|
||||
}
|
||||
}
|
||||
|
||||
type DNSOutput struct {
|
||||
Body struct {
|
||||
Servers []string `json:"servers" example:"[\"1.1.1.1\"]" doc:"Nameservers from /etc/resolv.conf"`
|
||||
}
|
||||
}
|
||||
|
||||
func registerReads(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-list-interfaces",
|
||||
Method: "GET",
|
||||
Path: "/api/networking/interfaces",
|
||||
Summary: "List network interfaces",
|
||||
Description: "Returns every interface with its state, MAC, MTU and addresses, via `ip -j addr`.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListInterfacesOutput, error) {
|
||||
out, err := oscmd.RunContext(ctx, "ip", "-j", "addr")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("ip addr failed", err)
|
||||
}
|
||||
ifaces, err := parseInterfaces(out)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("parse ip addr failed", err)
|
||||
}
|
||||
res := &ListInterfacesOutput{}
|
||||
res.Body.Interfaces = ifaces
|
||||
return res, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-list-routes",
|
||||
Method: "GET",
|
||||
Path: "/api/networking/routes",
|
||||
Summary: "List the IPv4 route table",
|
||||
Description: "Returns the kernel IPv4 route table via `ip -j route`.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListRoutesOutput, error) {
|
||||
out, err := oscmd.RunContext(ctx, "ip", "-j", "route")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("ip route failed", err)
|
||||
}
|
||||
routes, err := parseRoutes(out)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("parse ip route failed", err)
|
||||
}
|
||||
res := &ListRoutesOutput{}
|
||||
res.Body.Routes = routes
|
||||
return res, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-get-dns",
|
||||
Method: "GET",
|
||||
Path: "/api/networking/dns",
|
||||
Summary: "Get configured DNS servers",
|
||||
Description: "Returns the nameservers listed in /etc/resolv.conf. DNS is set " +
|
||||
"per-interface as part of the interface config (PUT /api/networking/interfaces/{name}), " +
|
||||
"so there is no standalone DNS write endpoint.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*DNSOutput, error) {
|
||||
data, err := os.ReadFile(resolvConf)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read resolv.conf failed", err)
|
||||
}
|
||||
res := &DNSOutput{}
|
||||
res.Body.Servers = parseResolv(string(data))
|
||||
return res, nil
|
||||
})
|
||||
}
|
||||
|
||||
// --- parsers (pure, tested) --------------------------------------------------
|
||||
|
||||
// ipAddr mirrors the fields we use from one `ip -j addr` element.
|
||||
type ipAddr struct {
|
||||
Name string `json:"ifname"`
|
||||
MAC string `json:"address"`
|
||||
MTU int `json:"mtu"`
|
||||
OperState string `json:"operstate"`
|
||||
AddrInfo []struct {
|
||||
Family string `json:"family"` // "inet" / "inet6"
|
||||
Local string `json:"local"`
|
||||
Prefix int `json:"prefixlen"`
|
||||
} `json:"addr_info"`
|
||||
}
|
||||
|
||||
func parseInterfaces(jsonOut string) ([]Interface, error) {
|
||||
var raw []ipAddr
|
||||
if err := json.Unmarshal([]byte(jsonOut), &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ifaces := make([]Interface, 0, len(raw))
|
||||
for _, r := range raw {
|
||||
iface := Interface{
|
||||
Name: r.Name,
|
||||
State: strings.ToLower(r.OperState),
|
||||
MAC: r.MAC,
|
||||
MTU: r.MTU,
|
||||
IPv4: []string{},
|
||||
IPv6: []string{},
|
||||
}
|
||||
for _, a := range r.AddrInfo {
|
||||
cidr := a.Local + "/" + strconv.Itoa(a.Prefix)
|
||||
if a.Family == "inet6" {
|
||||
iface.IPv6 = append(iface.IPv6, cidr)
|
||||
} else {
|
||||
iface.IPv4 = append(iface.IPv4, cidr)
|
||||
}
|
||||
}
|
||||
ifaces = append(ifaces, iface)
|
||||
}
|
||||
return ifaces, nil
|
||||
}
|
||||
|
||||
// ipRoute mirrors the fields we use from one `ip -j route` element.
|
||||
type ipRoute struct {
|
||||
Dst string `json:"dst"`
|
||||
Gateway string `json:"gateway"`
|
||||
Dev string `json:"dev"`
|
||||
PrefSrc string `json:"prefsrc"`
|
||||
Metric int `json:"metric"`
|
||||
}
|
||||
|
||||
func parseRoutes(jsonOut string) ([]RouteEntry, error) {
|
||||
var raw []ipRoute
|
||||
if err := json.Unmarshal([]byte(jsonOut), &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
routes := make([]RouteEntry, 0, len(raw))
|
||||
for _, r := range raw {
|
||||
routes = append(routes, RouteEntry{
|
||||
Destination: r.Dst,
|
||||
Gateway: r.Gateway,
|
||||
Interface: r.Dev,
|
||||
Source: r.PrefSrc,
|
||||
Metric: r.Metric,
|
||||
})
|
||||
}
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// parseResolv extracts "nameserver X" entries, ignoring comments and other
|
||||
// directives.
|
||||
func parseResolv(text string) []string {
|
||||
servers := []string{}
|
||||
for line := range strings.SplitSeq(text, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
|
||||
continue
|
||||
}
|
||||
if rest, ok := strings.CutPrefix(line, "nameserver"); ok {
|
||||
if s := strings.TrimSpace(rest); s != "" {
|
||||
servers = append(servers, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return servers
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultRollbackSeconds = 60
|
||||
|
||||
// errAlreadyPending is returned when another change is awaiting confirmation.
|
||||
// The write handlers map this to 409 Conflict.
|
||||
var errAlreadyPending = errors.New("already pending")
|
||||
|
||||
// errPending builds the 409 message. The lock is global across all interfaces
|
||||
// (see pendingChange), so the message says so to avoid confusing a user who is
|
||||
// touching a different interface than the one that holds the lock.
|
||||
func errPending(iface string) error {
|
||||
return fmt.Errorf("%w: a change to %s is awaiting confirmation. This is a global lock across all interfaces — confirm or roll that change back first", errAlreadyPending, iface)
|
||||
}
|
||||
|
||||
// startRollback snapshots the current state, applies the new config, and arms a
|
||||
// timer that auto-reverts if not confirmed. Returns errAlreadyPending (409) if
|
||||
// another change is in flight, or a wrapped error (500) if apply fails.
|
||||
//
|
||||
// Snapshot and Apply run WITHOUT the mutex held, so they don't block reads or
|
||||
// the pending-status endpoint while shelling out to nmcli/networkctl/ifup.
|
||||
func (m *Module) startRollback(ctx context.Context, iface string, cfg IfaceConfig) (int, error) {
|
||||
// Fast pre-check so we don't snapshot/apply when something is already
|
||||
// pending. armPending re-checks under the lock to close the race.
|
||||
if err := m.checkNoPending(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
prior, err := m.be.Snapshot(ctx, iface)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("snapshot %s: %w", iface, err)
|
||||
}
|
||||
if err := m.be.Apply(ctx, iface, cfg); err != nil {
|
||||
// Apply is not atomic: nmcli `con modify` may succeed before `con up`
|
||||
// fails, and networkd writes the .network file before `reconfigure`
|
||||
// runs. A failed Apply can therefore leave a half-applied config that
|
||||
// would otherwise have NO auto-revert (we bail before arming the timer).
|
||||
// Best-effort restore the snapshot so we never leave that unprotected.
|
||||
if rerr := m.be.Apply(ctx, iface, prior); rerr != nil {
|
||||
log.Printf("networking: apply %s failed and restore also failed: %v", iface, rerr)
|
||||
}
|
||||
return 0, fmt.Errorf("apply %s: %w", iface, err)
|
||||
}
|
||||
|
||||
// The revert runs from the timer or an explicit rollback, possibly with no
|
||||
// client attached, so it uses context.Background() rather than ctx.
|
||||
return m.armPending(iface, func() error { return m.be.Apply(context.Background(), iface, prior) }, cfg.RollbackSeconds)
|
||||
}
|
||||
|
||||
// startLinkDown takes the interface down behind the same rollback safety net: if
|
||||
// the change is not confirmed, the interface is brought back up. Taking a remote
|
||||
// interface down is just as much a lock-yourself-out risk as a bad static config.
|
||||
//
|
||||
// Bringing a link UP needs no protection (it cannot lock you out), so link-up
|
||||
// stays a direct, un-wrapped call in the handler.
|
||||
func (m *Module) startLinkDown(ctx context.Context, iface string) (int, error) {
|
||||
if err := m.checkNoPending(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := m.be.SetLinkDown(ctx, iface); err != nil {
|
||||
return 0, fmt.Errorf("link down %s: %w", iface, err)
|
||||
}
|
||||
// Revert (bring the link back up) may run from the timer with no client.
|
||||
return m.armPending(iface, func() error { return m.be.SetLinkUp(context.Background(), iface) }, 0)
|
||||
}
|
||||
|
||||
// checkNoPending reports a 409 error if a change is already pending.
|
||||
func (m *Module) checkNoPending() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.pending != nil {
|
||||
return errPending(m.pending.Iface)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// armPending installs the pending change and starts its auto-revert timer. The
|
||||
// caller has already applied the change; revert is the closure that undoes it.
|
||||
// It is invoked on timer expiry, explicit rollback, or if a concurrent change
|
||||
// raced us between the pre-check and here (in which case we revert immediately
|
||||
// and report the conflict). seconds <= 0 uses the default timeout.
|
||||
func (m *Module) armPending(iface string, revert func() error, seconds int) (int, error) {
|
||||
if seconds <= 0 {
|
||||
seconds = defaultRollbackSeconds
|
||||
}
|
||||
dur := time.Duration(seconds) * time.Second
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.pending != nil {
|
||||
// Lost the race — undo what we just applied and report conflict.
|
||||
if err := revert(); err != nil {
|
||||
log.Printf("networking: failed to undo raced change on %s: %v", iface, err)
|
||||
}
|
||||
return 0, errPending(m.pending.Iface)
|
||||
}
|
||||
|
||||
pc := &pendingChange{
|
||||
Iface: iface,
|
||||
revert: revert,
|
||||
Deadline: time.Now().Add(dur),
|
||||
}
|
||||
// The timer fires the auto-revert. It captures m and pc by closure so it can
|
||||
// revert even if the server is otherwise idle — the whole point is protecting
|
||||
// against being locked out of a remote box.
|
||||
pc.Timer = time.AfterFunc(dur, func() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Only revert if this exact change is still pending (it may have been
|
||||
// confirmed or manually rolled back in the meantime).
|
||||
if m.pending != pc {
|
||||
return
|
||||
}
|
||||
log.Printf("networking: rollback timer expired for %s — reverting", iface)
|
||||
if err := pc.revert(); err != nil {
|
||||
log.Printf("networking: auto-rollback of %s failed: %v", iface, err)
|
||||
}
|
||||
m.pending = nil
|
||||
})
|
||||
|
||||
m.pending = pc
|
||||
return seconds, nil
|
||||
}
|
||||
|
||||
// confirm cancels the rollback timer and clears the pending change, making it
|
||||
// permanent. Errors if there is no pending change or it's for another interface.
|
||||
func (m *Module) confirm(iface string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.pending == nil {
|
||||
return fmt.Errorf("no pending change to confirm")
|
||||
}
|
||||
if m.pending.Iface != iface {
|
||||
return fmt.Errorf("pending change is for %s, not %s", m.pending.Iface, iface)
|
||||
}
|
||||
|
||||
m.pending.Timer.Stop()
|
||||
m.pending = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackNow immediately reverts the pending change and clears it. Errors if
|
||||
// there is no pending change or it's for another interface.
|
||||
func (m *Module) rollbackNow(iface string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.pending == nil {
|
||||
return fmt.Errorf("no pending change to rollback")
|
||||
}
|
||||
if m.pending.Iface != iface {
|
||||
return fmt.Errorf("pending change is for %s, not %s", m.pending.Iface, iface)
|
||||
}
|
||||
|
||||
m.pending.Timer.Stop()
|
||||
err := m.pending.revert()
|
||||
m.pending = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("rollback %s: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PendingInfo is the JSON body returned by the pending-change status endpoint.
|
||||
type PendingInfo struct {
|
||||
Iface string `json:"interface" example:"eth0" doc:"Interface with a pending change"`
|
||||
SecondsRemaining int `json:"seconds_remaining" example:"45" doc:"Seconds until auto-rollback"`
|
||||
}
|
||||
|
||||
// pendingInfo returns the current pending change status, or nil if none.
|
||||
func (m *Module) pendingInfo() *PendingInfo {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.pending == nil {
|
||||
return nil
|
||||
}
|
||||
remaining := max(int(time.Until(m.pending.Deadline).Seconds()), 0)
|
||||
return &PendingInfo{
|
||||
Iface: m.pending.Iface,
|
||||
SecondsRemaining: remaining,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
package networking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// registerWrites adds all write endpoints for the networking module. Every
|
||||
// handler checks m.be != nil first and returns 501 when no backend was detected.
|
||||
|
||||
type ApplyInput struct {
|
||||
Name string `path:"name" example:"eth0" doc:"Interface name"`
|
||||
Body IfaceConfig `doc:"Desired interface configuration"`
|
||||
}
|
||||
|
||||
type ApplyOutput struct {
|
||||
Body struct {
|
||||
Status string `json:"status" example:"pending" doc:"Always \"pending\" — confirm to make permanent"`
|
||||
Interface string `json:"interface" example:"eth0"`
|
||||
Backend string `json:"backend" example:"nmcli" doc:"Network backend that applied the change"`
|
||||
RollbackSeconds int `json:"rollback_seconds" example:"60" doc:"Seconds until auto-rollback unless confirmed"`
|
||||
}
|
||||
}
|
||||
|
||||
type IfacePathInput struct {
|
||||
Name string `path:"name" example:"eth0" doc:"Interface name"`
|
||||
}
|
||||
|
||||
type PendingOutput struct {
|
||||
Body PendingInfo
|
||||
}
|
||||
|
||||
func registerWrites(api huma.API, m *Module) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-apply-config",
|
||||
Method: "PUT",
|
||||
Path: "/api/networking/interfaces/{name}",
|
||||
Summary: "Apply interface configuration",
|
||||
Description: "Replaces the interface's IPv4 configuration. The change is applied " +
|
||||
"immediately but starts a rollback timer — if not confirmed within the timeout " +
|
||||
"(default 60s), the prior configuration is automatically restored. This prevents " +
|
||||
"lock-yourself-out mistakes on remote hosts.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *ApplyInput) (*ApplyOutput, error) {
|
||||
if m.be == nil {
|
||||
return nil, huma.Error501NotImplemented("", errNoBackend)
|
||||
}
|
||||
if err := validateIface(in.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := in.Body.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seconds, err := m.startRollback(ctx, in.Name, in.Body)
|
||||
if err != nil {
|
||||
if errors.Is(err, errAlreadyPending) {
|
||||
return nil, huma.Error409Conflict(err.Error())
|
||||
}
|
||||
return nil, huma.Error500InternalServerError("apply failed", err)
|
||||
}
|
||||
|
||||
out := &ApplyOutput{}
|
||||
out.Body.Status = "pending"
|
||||
out.Body.Interface = in.Name
|
||||
out.Body.Backend = m.be.Name()
|
||||
out.Body.RollbackSeconds = seconds
|
||||
return out, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-confirm-change",
|
||||
Method: "POST",
|
||||
Path: "/api/networking/interfaces/{name}/confirm",
|
||||
Summary: "Confirm a pending change",
|
||||
Description: "Cancels the rollback timer, making the applied configuration permanent.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *IfacePathInput) (*oscmd.StatusOutput, error) {
|
||||
if m.be == nil {
|
||||
return nil, huma.Error501NotImplemented("", errNoBackend)
|
||||
}
|
||||
if err := validateIface(in.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.confirm(in.Name); err != nil {
|
||||
return nil, huma.Error409Conflict(err.Error())
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-rollback-change",
|
||||
Method: "POST",
|
||||
Path: "/api/networking/interfaces/{name}/rollback",
|
||||
Summary: "Immediately revert a pending change",
|
||||
Description: "Reverts the interface to its prior configuration and clears the pending change.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *IfacePathInput) (*oscmd.StatusOutput, error) {
|
||||
if m.be == nil {
|
||||
return nil, huma.Error501NotImplemented("", errNoBackend)
|
||||
}
|
||||
if err := validateIface(in.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.rollbackNow(in.Name); err != nil {
|
||||
return nil, huma.Error500InternalServerError("rollback failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-link-up",
|
||||
Method: "POST",
|
||||
Path: "/api/networking/interfaces/{name}/up",
|
||||
Summary: "Bring an interface up",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *IfacePathInput) (*oscmd.StatusOutput, error) {
|
||||
if m.be == nil {
|
||||
return nil, huma.Error501NotImplemented("", errNoBackend)
|
||||
}
|
||||
if err := validateIface(in.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.be.SetLinkUp(ctx, in.Name); err != nil {
|
||||
return nil, huma.Error500InternalServerError("link up failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-link-down",
|
||||
Method: "POST",
|
||||
Path: "/api/networking/interfaces/{name}/down",
|
||||
Summary: "Take an interface down",
|
||||
Description: "Brings the interface down behind the rollback safety net: it is brought " +
|
||||
"back up automatically if not confirmed within the timeout (default 60s). This " +
|
||||
"prevents taking down the link you're managing the host over and losing access.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *IfacePathInput) (*ApplyOutput, error) {
|
||||
if m.be == nil {
|
||||
return nil, huma.Error501NotImplemented("", errNoBackend)
|
||||
}
|
||||
if err := validateIface(in.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seconds, err := m.startLinkDown(ctx, in.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, errAlreadyPending) {
|
||||
return nil, huma.Error409Conflict(err.Error())
|
||||
}
|
||||
return nil, huma.Error500InternalServerError("link down failed", err)
|
||||
}
|
||||
out := &ApplyOutput{}
|
||||
out.Body.Status = "pending"
|
||||
out.Body.Interface = in.Name
|
||||
out.Body.Backend = m.be.Name()
|
||||
out.Body.RollbackSeconds = seconds
|
||||
return out, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "networking-get-pending",
|
||||
Method: "GET",
|
||||
Path: "/api/networking/pending",
|
||||
Summary: "Get pending change status",
|
||||
Description: "Returns the currently pending change (interface name and seconds until " +
|
||||
"auto-rollback), or 404 if there is no pending change.",
|
||||
Tags: []string{tagNetworking},
|
||||
Metadata: op("read"),
|
||||
Errors: []int{401, 403, 404, 500},
|
||||
}, func(ctx context.Context, _ *struct{}) (*PendingOutput, error) {
|
||||
info := m.pendingInfo()
|
||||
if info == nil {
|
||||
return nil, huma.Error404NotFound("no pending change")
|
||||
}
|
||||
out := &PendingOutput{}
|
||||
out.Body = *info
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package packages
|
||||
|
||||
import (
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const ModuleID = "packages"
|
||||
|
||||
type Module struct {
|
||||
pm manager // detected package manager; zero value means none found
|
||||
}
|
||||
|
||||
// New detects the host's package manager once at startup.
|
||||
func New() *Module { return &Module{pm: detect()} }
|
||||
|
||||
func (m *Module) ID() string { return ModuleID }
|
||||
func (m *Module) Name() string { return "Packages" }
|
||||
|
||||
// Permissions: read to list installed/available; write to install, remove, and
|
||||
// upgrade.
|
||||
func (m *Module) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Read, rbac.Write}
|
||||
}
|
||||
|
||||
func (m *Module) Register(api huma.API) {
|
||||
registerPackages(api, m.pm)
|
||||
}
|
||||
|
||||
func op(permission string) map[string]any {
|
||||
return map[string]any{"module": ModuleID, "permission": permission}
|
||||
}
|
||||
@@ -0,0 +1,370 @@
|
||||
package packages
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/sse"
|
||||
)
|
||||
|
||||
const tagPackages = "Packages"
|
||||
|
||||
var (
|
||||
readErrors = []int{401, 403, 500}
|
||||
writeErrors = []int{400, 401, 403, 500}
|
||||
)
|
||||
|
||||
// pkgNameRe matches a plain package name (no version specifiers, no shell
|
||||
// metacharacters). Combined with a leading-dash reject and "--" before the name
|
||||
// on every command line, it keeps user input from being read as a flag.
|
||||
var pkgNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9@._+-]*$`)
|
||||
|
||||
// manager identifies the host package manager. Its name drives a switch in each
|
||||
// operation rather than an interface - three concrete tools, no polymorphism.
|
||||
type manager struct{ name string }
|
||||
|
||||
// detect picks the package manager. dnf/pacman are checked before apt because a
|
||||
// few rpm-based boxes also ship an apt-get shim; the native tool should win.
|
||||
func detect() manager {
|
||||
for _, m := range []struct{ name, bin string }{
|
||||
{"dnf", "dnf"},
|
||||
{"pacman", "pacman"},
|
||||
{"apt", "apt-get"},
|
||||
} {
|
||||
if _, err := exec.LookPath(m.bin); err == nil {
|
||||
return manager{name: m.name}
|
||||
}
|
||||
}
|
||||
return manager{}
|
||||
}
|
||||
|
||||
type Package struct {
|
||||
Name string `json:"name" example:"openssh-server" doc:"Package name"`
|
||||
Version string `json:"version" example:"9.6p1" doc:"Installed version, or the available version for updates"`
|
||||
}
|
||||
|
||||
type ListOutput struct {
|
||||
Body struct {
|
||||
Manager string `json:"manager" example:"dnf" doc:"Detected package manager"`
|
||||
Packages []Package `json:"packages"`
|
||||
}
|
||||
}
|
||||
|
||||
type InstallInput struct {
|
||||
Body struct {
|
||||
Name string `json:"name" example:"htop" doc:"Package to install"`
|
||||
}
|
||||
}
|
||||
|
||||
type RemoveInput struct {
|
||||
Name string `path:"name" example:"htop" doc:"Package to remove"`
|
||||
}
|
||||
|
||||
// SSE event types for streaming package operations.
|
||||
type PkgOutputEvent struct {
|
||||
Line string `json:"line" doc:"One line of the package manager's terminal output"`
|
||||
}
|
||||
|
||||
type PkgErrorEvent struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PkgDoneEvent struct {
|
||||
Success bool `json:"success" doc:"True if the package manager exited 0"`
|
||||
Error string `json:"error,omitempty" doc:"Exit error when it failed"`
|
||||
}
|
||||
|
||||
// pkgEvents maps SSE event names to their payload types for the streaming
|
||||
// install/remove/upgrade operations.
|
||||
var pkgEvents = map[string]any{
|
||||
"output": PkgOutputEvent{},
|
||||
"done": PkgDoneEvent{},
|
||||
"error": PkgErrorEvent{},
|
||||
}
|
||||
|
||||
func registerPackages(api huma.API, pm manager) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "packages-list-installed",
|
||||
Method: "GET",
|
||||
Path: "/api/packages",
|
||||
Summary: "List installed packages",
|
||||
Tags: []string{tagPackages},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListOutput, error) {
|
||||
return listInstalled(pm)
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "packages-list-updates",
|
||||
Method: "GET",
|
||||
Path: "/api/packages/updates",
|
||||
Summary: "List available updates",
|
||||
Tags: []string{tagPackages},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListOutput, error) {
|
||||
return listUpdates(pm)
|
||||
})
|
||||
|
||||
// Install/remove/upgrade stream the package manager's terminal output live
|
||||
// via SSE (one `output` event per line, a final `done` with exit status),
|
||||
// rather than blocking the request until a long operation finishes.
|
||||
sse.Register(api, huma.Operation{
|
||||
OperationID: "packages-install",
|
||||
Method: "POST",
|
||||
Path: "/api/packages",
|
||||
Summary: "Install a package (streamed)",
|
||||
Description: "Installs a package, streaming the package manager's output as " +
|
||||
"`output` events and ending with a `done` event carrying the exit status.",
|
||||
Tags: []string{tagPackages},
|
||||
Metadata: op("write"),
|
||||
}, pkgEvents, func(ctx context.Context, in *InstallInput, send sse.Sender) {
|
||||
if validateName(in.Body.Name) != nil {
|
||||
send.Data(PkgErrorEvent{Message: "invalid package name: " + in.Body.Name})
|
||||
return
|
||||
}
|
||||
bin, args := pm.installArgs(in.Body.Name)
|
||||
streamOp(ctx, send, bin, args)
|
||||
})
|
||||
|
||||
sse.Register(api, huma.Operation{
|
||||
OperationID: "packages-remove",
|
||||
Method: "DELETE",
|
||||
Path: "/api/packages/{name}",
|
||||
Summary: "Remove a package (streamed)",
|
||||
Tags: []string{tagPackages},
|
||||
Metadata: op("write"),
|
||||
}, pkgEvents, func(ctx context.Context, in *RemoveInput, send sse.Sender) {
|
||||
if validateName(in.Name) != nil {
|
||||
send.Data(PkgErrorEvent{Message: "invalid package name: " + in.Name})
|
||||
return
|
||||
}
|
||||
bin, args := pm.removeArgs(in.Name)
|
||||
streamOp(ctx, send, bin, args)
|
||||
})
|
||||
|
||||
sse.Register(api, huma.Operation{
|
||||
OperationID: "packages-upgrade",
|
||||
Method: "POST",
|
||||
Path: "/api/packages/upgrade",
|
||||
Summary: "Upgrade all packages (streamed)",
|
||||
Description: "Upgrades every package to its latest version, streaming the " +
|
||||
"package manager's output live.",
|
||||
Tags: []string{tagPackages},
|
||||
Metadata: op("write"),
|
||||
}, pkgEvents, func(ctx context.Context, _ *struct{}, send sse.Sender) {
|
||||
bin, args := pm.upgradeArgs()
|
||||
streamOp(ctx, send, bin, args)
|
||||
})
|
||||
}
|
||||
|
||||
// streamOp runs a package write and streams its combined output to the client.
|
||||
func streamOp(ctx context.Context, send sse.Sender, bin string, args []string) {
|
||||
if bin == "" {
|
||||
send.Data(PkgErrorEvent{Message: "no supported package manager found"})
|
||||
return
|
||||
}
|
||||
// DEBIAN_FRONTEND keeps apt from blocking on an interactive prompt.
|
||||
lines, errc, err := oscmd.RunStreamCombined(ctx, []string{"DEBIAN_FRONTEND=noninteractive"}, bin, args...)
|
||||
if err != nil {
|
||||
send.Data(PkgErrorEvent{Message: err.Error()})
|
||||
return
|
||||
}
|
||||
for line := range lines {
|
||||
if send.Data(PkgOutputEvent{Line: line}) != nil {
|
||||
return // client gone; ctx cancel kills the process
|
||||
}
|
||||
}
|
||||
done := PkgDoneEvent{Success: true}
|
||||
if werr := <-errc; werr != nil {
|
||||
done.Success = false
|
||||
done.Error = werr.Error()
|
||||
}
|
||||
send.Data(done)
|
||||
}
|
||||
|
||||
func validateName(name string) error {
|
||||
if !pkgNameRe.MatchString(name) {
|
||||
return huma.Error400BadRequest("invalid package name: " + name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- reads -------------------------------------------------------------------
|
||||
|
||||
func listInstalled(pm manager) (*ListOutput, error) {
|
||||
var out string
|
||||
var err error
|
||||
switch pm.name {
|
||||
case "dnf":
|
||||
out, err = oscmd.Run("rpm", "-qa", "--qf", "%{NAME}\t%{VERSION}-%{RELEASE}\n")
|
||||
case "apt":
|
||||
out, err = oscmd.Run("dpkg-query", "-W", "-f=${Package}\t${Version}\n")
|
||||
case "pacman":
|
||||
out, err = oscmd.Run("pacman", "-Q")
|
||||
default:
|
||||
return nil, huma.Error500InternalServerError("no supported package manager found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("listing installed packages failed", err)
|
||||
}
|
||||
pkgs := parseTabbed(out)
|
||||
if pm.name == "pacman" {
|
||||
pkgs = parseSpaced(out)
|
||||
}
|
||||
return result(pm, pkgs), nil
|
||||
}
|
||||
|
||||
func listUpdates(pm manager) (*ListOutput, error) {
|
||||
switch pm.name {
|
||||
case "dnf":
|
||||
// check-update exits 100 when updates exist, 0 when none - both fine.
|
||||
out, code, err := oscmd.RunStatus("dnf", "-q", "check-update")
|
||||
if err != nil || (code != 0 && code != 100) {
|
||||
return nil, huma.Error500InternalServerError("dnf check-update failed", err)
|
||||
}
|
||||
return result(pm, parseDnf(out)), nil
|
||||
case "apt":
|
||||
out, err := oscmd.Run("apt", "list", "--upgradable", "-qq")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("apt list failed", err)
|
||||
}
|
||||
return result(pm, parseApt(out)), nil
|
||||
case "pacman":
|
||||
// -Qu exits 1 when there is nothing to upgrade - not an error.
|
||||
out, code, err := oscmd.RunStatus("pacman", "-Qu")
|
||||
if err != nil || (code != 0 && code != 1) {
|
||||
return nil, huma.Error500InternalServerError("pacman -Qu failed", err)
|
||||
}
|
||||
return result(pm, parsePacmanUpdates(out)), nil
|
||||
default:
|
||||
return nil, huma.Error500InternalServerError("no supported package manager found")
|
||||
}
|
||||
}
|
||||
|
||||
func result(pm manager, pkgs []Package) *ListOutput {
|
||||
out := &ListOutput{}
|
||||
out.Body.Manager = pm.name
|
||||
out.Body.Packages = pkgs
|
||||
return out
|
||||
}
|
||||
|
||||
// --- writes ------------------------------------------------------------------
|
||||
|
||||
func (m manager) installArgs(name string) (string, []string) {
|
||||
switch m.name {
|
||||
case "dnf":
|
||||
return "dnf", []string{"install", "-y", "--", name}
|
||||
case "apt":
|
||||
return "apt-get", []string{"install", "-y", "--", name}
|
||||
case "pacman":
|
||||
return "pacman", []string{"-S", "--noconfirm", "--", name}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m manager) removeArgs(name string) (string, []string) {
|
||||
switch m.name {
|
||||
case "dnf":
|
||||
return "dnf", []string{"remove", "-y", "--", name}
|
||||
case "apt":
|
||||
return "apt-get", []string{"remove", "-y", "--", name}
|
||||
case "pacman":
|
||||
return "pacman", []string{"-R", "--noconfirm", "--", name}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m manager) upgradeArgs() (string, []string) {
|
||||
switch m.name {
|
||||
case "dnf":
|
||||
return "dnf", []string{"upgrade", "-y"}
|
||||
case "apt":
|
||||
return "apt-get", []string{"upgrade", "-y"}
|
||||
case "pacman":
|
||||
return "pacman", []string{"-Su", "--noconfirm"}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// --- parsers (pure, tested) --------------------------------------------------
|
||||
|
||||
// parseTabbed reads "name\tversion" lines (dpkg-query / rpm output).
|
||||
func parseTabbed(out string) []Package {
|
||||
pkgs := []Package{}
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
if name, ver, ok := strings.Cut(line, "\t"); ok && name != "" {
|
||||
pkgs = append(pkgs, Package{Name: name, Version: ver})
|
||||
}
|
||||
}
|
||||
return pkgs
|
||||
}
|
||||
|
||||
// parseSpaced reads "name version" lines (pacman -Q).
|
||||
func parseSpaced(out string) []Package {
|
||||
pkgs := []Package{}
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
if f := strings.Fields(line); len(f) >= 2 {
|
||||
pkgs = append(pkgs, Package{Name: f[0], Version: f[1]})
|
||||
}
|
||||
}
|
||||
return pkgs
|
||||
}
|
||||
|
||||
// parseDnf reads `dnf check-update` lines ("name.arch version repo"), skipping
|
||||
// the section headers and blank lines that dnf5 emits.
|
||||
func parseDnf(out string) []Package {
|
||||
pkgs := []Package{}
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
f := strings.Fields(line)
|
||||
// Real rows have 3 columns and an arch suffix on the name; headers like
|
||||
// "Upgrades" or "Obsoleting Packages" don't.
|
||||
if len(f) != 3 || !strings.Contains(f[0], ".") {
|
||||
continue
|
||||
}
|
||||
pkgs = append(pkgs, Package{Name: stripArch(f[0]), Version: f[1]})
|
||||
}
|
||||
return pkgs
|
||||
}
|
||||
|
||||
// parseApt reads `apt list --upgradable` lines ("name/repo version arch [...]").
|
||||
func parseApt(out string) []Package {
|
||||
pkgs := []Package{}
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
f := strings.Fields(line)
|
||||
if len(f) < 2 || !strings.Contains(f[0], "/") {
|
||||
continue
|
||||
}
|
||||
name, _, _ := strings.Cut(f[0], "/")
|
||||
pkgs = append(pkgs, Package{Name: name, Version: f[1]})
|
||||
}
|
||||
return pkgs
|
||||
}
|
||||
|
||||
// parsePacmanUpdates reads `pacman -Qu` lines ("name oldver -> newver").
|
||||
func parsePacmanUpdates(out string) []Package {
|
||||
pkgs := []Package{}
|
||||
for line := range strings.SplitSeq(out, "\n") {
|
||||
f := strings.Fields(line)
|
||||
if len(f) < 4 || f[2] != "->" {
|
||||
continue
|
||||
}
|
||||
pkgs = append(pkgs, Package{Name: f[0], Version: f[3]})
|
||||
}
|
||||
return pkgs
|
||||
}
|
||||
|
||||
// stripArch removes the trailing ".arch" from an rpm "name.arch" token. The
|
||||
// arch is always the final dotted segment, so cut at the last dot.
|
||||
func stripArch(s string) string {
|
||||
if i := strings.LastIndex(s, "."); i > 0 {
|
||||
return s[:i]
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package packages
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestPackagesHandlers(t *testing.T) {
|
||||
managers := []string{"dnf", "apt", "pacman"}
|
||||
|
||||
for _, mgrName := range managers {
|
||||
t.Run(mgrName, func(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
|
||||
m := &Module{pm: manager{name: mgrName}}
|
||||
m.Register(api)
|
||||
|
||||
oscmd.SetMock("rpm", func(args []string) oscmd.MockCommand {
|
||||
return oscmd.MockCommand{Stdout: "htop\t3.3.0-1\n", ExitCode: 0}
|
||||
})
|
||||
oscmd.SetMock("dpkg-query", func(args []string) oscmd.MockCommand {
|
||||
return oscmd.MockCommand{Stdout: "htop\t3.3.0-1\n", ExitCode: 0}
|
||||
})
|
||||
oscmd.SetMock("pacman", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"-Q"}) {
|
||||
return oscmd.MockCommand{Stdout: "htop 3.3.0-1\n", ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"-Qu"}) {
|
||||
return oscmd.MockCommand{Stdout: "linux 6.9.1-1 -> 6.9.2-1\n", ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{Lines: []string{"pacman output"}, ExitCode: 0}
|
||||
})
|
||||
oscmd.SetMock("dnf", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"-q", "check-update"}) {
|
||||
return oscmd.MockCommand{Stdout: "Upgrades\ncode.x86_64 1.125.1-1 code\n", ExitCode: 100}
|
||||
}
|
||||
return oscmd.MockCommand{Lines: []string{"dnf output"}, ExitCode: 0}
|
||||
})
|
||||
oscmd.SetMock("apt", func(args []string) oscmd.MockCommand {
|
||||
return oscmd.MockCommand{Stdout: "Listing...\nvim/jammy 2:8.2.3995 amd64 [upgradable]\n", ExitCode: 0}
|
||||
})
|
||||
oscmd.SetMock("apt-get", func(args []string) oscmd.MockCommand {
|
||||
return oscmd.MockCommand{Lines: []string{"apt-get output"}, ExitCode: 0}
|
||||
})
|
||||
defer oscmd.ClearMocks()
|
||||
|
||||
// Test list installed
|
||||
resp := api.Get("/api/packages")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list installed: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var listRes ListOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(listRes.Body.Packages) != 1 || listRes.Body.Packages[0].Name != "htop" {
|
||||
t.Errorf("unexpected installed packages: %+v", listRes.Body)
|
||||
}
|
||||
|
||||
// Test list updates
|
||||
resp = api.Get("/api/packages/updates")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list updates: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(listRes.Body.Packages) != 1 {
|
||||
t.Errorf("unexpected updates: %+v", listRes.Body)
|
||||
}
|
||||
|
||||
// Test install (SSE)
|
||||
resp = api.Post("/api/packages", struct {
|
||||
Name string `json:"name"`
|
||||
}{
|
||||
Name: "htop",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("install: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
bodyStr := resp.Body.String()
|
||||
if !strings.Contains(bodyStr, "done") {
|
||||
t.Errorf("install stream output missing done: %q", bodyStr)
|
||||
}
|
||||
|
||||
// Test remove (SSE)
|
||||
resp = api.Delete("/api/packages/htop")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("remove: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
bodyStr = resp.Body.String()
|
||||
if !strings.Contains(bodyStr, "done") {
|
||||
t.Errorf("remove stream output missing done: %q", bodyStr)
|
||||
}
|
||||
|
||||
// Test upgrade (SSE)
|
||||
resp = api.Post("/api/packages/upgrade", struct{}{})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("upgrade: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
bodyStr = resp.Body.String()
|
||||
if !strings.Contains(bodyStr, "done") {
|
||||
t.Errorf("upgrade stream output missing done: %q", bodyStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package packages
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseTabbed(t *testing.T) {
|
||||
out := "zeromq\t4.3.5-22.fc43\nspice-server\t0.16.0-2.fc43\n\nbad-no-tab\n"
|
||||
got := parseTabbed(out)
|
||||
want := []Package{{"zeromq", "4.3.5-22.fc43"}, {"spice-server", "0.16.0-2.fc43"}}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %+v, want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSpaced(t *testing.T) {
|
||||
got := parseSpaced("linux 6.9.1\nhtop 3.3.0\n")
|
||||
want := []Package{{"linux", "6.9.1"}, {"htop", "3.3.0"}}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %+v, want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDnf(t *testing.T) {
|
||||
// dnf5 emits a section header ("Upgrades") that must be skipped.
|
||||
out := "Upgrades\n" +
|
||||
"code.x86_64 1.125.1-1781859648.el8 code\n" +
|
||||
"containerd.io.x86_64 2.2.5-1.fc44 docker-ce-stable\n" +
|
||||
"\nObsoleting Packages\n"
|
||||
got := parseDnf(out)
|
||||
want := []Package{{"code", "1.125.1-1781859648.el8"}, {"containerd.io", "2.2.5-1.fc44"}}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %+v, want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseApt(t *testing.T) {
|
||||
out := "Listing...\n" +
|
||||
"vim/jammy-updates 2:8.2.3995-1ubuntu2.15 amd64 [upgradable from: 2:8.2.3995-1ubuntu2.1]\n"
|
||||
got := parseApt(out)
|
||||
want := []Package{{"vim", "2:8.2.3995-1ubuntu2.15"}}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %+v, want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePacmanUpdates(t *testing.T) {
|
||||
got := parsePacmanUpdates("linux 6.9.1-1 -> 6.9.2-1\nfoo 1.0 1.0\n")
|
||||
want := []Package{{"linux", "6.9.2-1"}}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %+v, want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripArch(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"code.x86_64": "code",
|
||||
"python3.11.noarch": "python3.11", // arch is only the final segment
|
||||
"noarchhere": "noarchhere",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := stripArch(in); got != want {
|
||||
t.Errorf("stripArch(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateName(t *testing.T) {
|
||||
for _, n := range []string{"htop", "openssh-server", "lib32-glibc", "g++", "python3.11"} {
|
||||
if err := validateName(n); err != nil {
|
||||
t.Errorf("validateName(%q) = %v, want nil", n, err)
|
||||
}
|
||||
}
|
||||
for _, n := range []string{"", "-rf", "foo;rm", "foo bar", "pkg=1.0", "a/b"} {
|
||||
if err := validateName(n); err == nil {
|
||||
t.Errorf("validateName(%q) = nil, want error", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/sse"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLogLines = 100
|
||||
maxLogLines = 10000
|
||||
)
|
||||
|
||||
// LogEntry is one log record. For the journal source it is distilled from
|
||||
// journalctl's JSON; for the file source only Message is set (the raw line,
|
||||
// which usually carries its own embedded timestamp).
|
||||
type LogEntry struct {
|
||||
Time string `json:"time" example:"2026-06-20T08:15:04Z" doc:"Record timestamp (RFC3339, UTC); empty for file lines"`
|
||||
Priority int `json:"priority" example:"6" doc:"syslog priority 0 (emerg) – 7 (debug); 6 for file lines"`
|
||||
Message string `json:"message" example:"Started OpenSSH server daemon."`
|
||||
}
|
||||
|
||||
// ErrorEvent is an SSE event carrying a stream-level error (e.g. bad unit).
|
||||
type ErrorEvent struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type LogsInput struct {
|
||||
Unit string `path:"unit" example:"docker.service" doc:"Unit name as listed by GET /api/services; the trailing .service is optional"`
|
||||
Source string `query:"source" enum:"journal,file" default:"journal" doc:"Where to read logs from"`
|
||||
Path string `query:"path" example:"/var/log/nginx/error.log" doc:"Log file (file source only); must be allowlisted for this unit in config"`
|
||||
Lines int `query:"lines" default:"100" doc:"How many recent records to return (max 10000)"`
|
||||
Since string `query:"since" example:"-1h" doc:"journalctl time filter (journal source only)"`
|
||||
Priority int `query:"priority" default:"7" minimum:"0" maximum:"7" doc:"Max syslog priority to include: 0 emerg .. 7 debug (journal source only). 7 = all."`
|
||||
}
|
||||
|
||||
type LogsOutput struct {
|
||||
Body struct {
|
||||
Entries []LogEntry `json:"entries" doc:"Log records, oldest first"`
|
||||
}
|
||||
}
|
||||
|
||||
type LogStreamInput struct {
|
||||
Unit string `path:"unit" example:"docker.service" doc:"Unit name as listed by GET /api/services; the trailing .service is optional"`
|
||||
Source string `query:"source" enum:"journal,file" default:"journal" doc:"Where to stream logs from"`
|
||||
Path string `query:"path" example:"/var/log/nginx/error.log" doc:"Log file (file source only); must be allowlisted for this unit in config"`
|
||||
Since string `query:"since" example:"-1h" doc:"Backfill window (journal source only)"`
|
||||
Priority int `query:"priority" default:"7" minimum:"0" maximum:"7" doc:"Max syslog priority to include: 0 emerg .. 7 debug (journal source only). 7 = all."`
|
||||
}
|
||||
|
||||
func registerLogs(api huma.API, logFiles map[string][]string) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "services-logs",
|
||||
Method: "GET",
|
||||
Path: "/api/services/{unit}/logs",
|
||||
Summary: "Get recent log records for a service",
|
||||
Description: "Returns a snapshot of the unit's logs from the journal " +
|
||||
"(default) or an allowlisted file (source=file&path=). Use /logs/stream " +
|
||||
"to follow new records live.",
|
||||
Tags: []string{tagServices},
|
||||
Metadata: op("read"),
|
||||
// No 404: the journal is historical, so logs are returned even for units
|
||||
// that aren't currently loaded; an unknown unit just yields an empty list.
|
||||
Errors: []int{400, 401, 403, 500},
|
||||
}, func(ctx context.Context, in *LogsInput) (*LogsOutput, error) {
|
||||
if err := validateUnit(in.Unit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var lines []string
|
||||
var err error
|
||||
if in.Source == "file" {
|
||||
path, perr := resolveLogPath(logFiles, in.Unit, in.Path)
|
||||
if perr != nil {
|
||||
return nil, perr
|
||||
}
|
||||
lines, err = oscmd.RunLines("tail", "-n", strconv.Itoa(clampLines(in.Lines)), "--", path)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("tail failed", err)
|
||||
}
|
||||
return fileOutput(lines), nil
|
||||
}
|
||||
|
||||
args := []string{"-u", journalUnit(in.Unit), "--no-pager", "-o", "json", "-p", strconv.Itoa(in.Priority), "-n", strconv.Itoa(clampLines(in.Lines))}
|
||||
if in.Since != "" {
|
||||
args = append(args, "--since", in.Since)
|
||||
}
|
||||
lines, err = oscmd.RunLines("journalctl", args...)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("journalctl failed", err)
|
||||
}
|
||||
out := &LogsOutput{}
|
||||
out.Body.Entries = []LogEntry{}
|
||||
for _, l := range lines {
|
||||
if e, ok := parseJournalLine([]byte(l)); ok {
|
||||
out.Body.Entries = append(out.Body.Entries, e)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
})
|
||||
|
||||
// Streaming via huma's sse package keeps the route inside huma, so the RBAC
|
||||
// middleware still enforces op("read") - a raw mux handler would bypass it.
|
||||
sse.Register(api, huma.Operation{
|
||||
OperationID: "services-logs-stream",
|
||||
Method: "GET",
|
||||
Path: "/api/services/{unit}/logs/stream",
|
||||
Summary: "Stream a service's logs (Server-Sent Events)",
|
||||
Description: "Follows the unit's journal (journalctl -f) or an allowlisted " +
|
||||
"file (source=file&path=, via tail -F) and emits a `log` event per " +
|
||||
"record. Stops when the client disconnects.",
|
||||
Tags: []string{tagServices},
|
||||
Metadata: op("read"),
|
||||
}, map[string]any{
|
||||
"log": LogEntry{},
|
||||
"error": ErrorEvent{},
|
||||
}, func(ctx context.Context, in *LogStreamInput, send sse.Sender) {
|
||||
if err := validateUnit(in.Unit); err != nil {
|
||||
send.Data(ErrorEvent{Message: "invalid unit name"})
|
||||
return
|
||||
}
|
||||
|
||||
var cmd string
|
||||
var args []string
|
||||
if in.Source == "file" {
|
||||
path, perr := resolveLogPath(logFiles, in.Unit, in.Path)
|
||||
if perr != nil {
|
||||
send.Data(ErrorEvent{Message: perr.Error()})
|
||||
return
|
||||
}
|
||||
cmd, args = "tail", []string{"-n", strconv.Itoa(defaultLogLines), "-F", "--", path}
|
||||
} else {
|
||||
cmd = "journalctl"
|
||||
args = []string{"-u", journalUnit(in.Unit), "--no-pager", "-o", "json", "-p", strconv.Itoa(in.Priority), "-f"}
|
||||
if in.Since != "" {
|
||||
args = append(args, "--since", in.Since)
|
||||
}
|
||||
}
|
||||
|
||||
lines, err := oscmd.RunStream(ctx, cmd, args...)
|
||||
if err != nil {
|
||||
send.Data(ErrorEvent{Message: cmd + " failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
for l := range lines {
|
||||
e, ok := LogEntry{Priority: 6, Message: l}, true
|
||||
if in.Source != "file" {
|
||||
e, ok = parseJournalLine([]byte(l))
|
||||
}
|
||||
if ok {
|
||||
if send.Data(e) != nil {
|
||||
return // client gone; ctx cancel will kill the command
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// resolveLogPath validates that path is allowlisted for unit. The caller never
|
||||
// gets to point exec at an arbitrary file - only paths an admin listed under
|
||||
// log_files for this unit are accepted.
|
||||
func resolveLogPath(logFiles map[string][]string, unit, path string) (string, error) {
|
||||
if path == "" {
|
||||
return "", huma.Error400BadRequest("source=file requires a path")
|
||||
}
|
||||
// Match the allowlist key suffix-insensitively (nginx == nginx.service), so
|
||||
// it behaves like the journal source regardless of which form the caller and
|
||||
// the config author each used.
|
||||
want := journalUnit(unit)
|
||||
for key, paths := range logFiles {
|
||||
if journalUnit(key) == want && slices.Contains(paths, path) {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
return "", huma.Error403Forbidden("log file not allowlisted for unit " + unit + ": " + path)
|
||||
}
|
||||
|
||||
func fileOutput(lines []string) *LogsOutput {
|
||||
out := &LogsOutput{}
|
||||
out.Body.Entries = make([]LogEntry, 0, len(lines))
|
||||
for _, l := range lines {
|
||||
out.Body.Entries = append(out.Body.Entries, LogEntry{Priority: 6, Message: l})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// journalUnit normalizes a unit name for `journalctl -u`. journalctl treats a
|
||||
// bare name as a .service, and on some setups only the bare form matches the
|
||||
// recorded _SYSTEMD_UNIT, so we always strip the suffix. This is the services
|
||||
// module, so .service is the only suffix we expect.
|
||||
func journalUnit(unit string) string {
|
||||
return strings.TrimSuffix(unit, ".service")
|
||||
}
|
||||
|
||||
func clampLines(n int) int {
|
||||
switch {
|
||||
case n <= 0:
|
||||
return defaultLogLines
|
||||
case n > maxLogLines:
|
||||
return maxLogLines
|
||||
default:
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
// parseJournalLine distills one journalctl `-o json` record. Returns false for
|
||||
// unparseable lines. Binary MESSAGE fields (encoded as a byte array rather than
|
||||
// a string) yield an empty message rather than an error.
|
||||
func parseJournalLine(line []byte) (LogEntry, bool) {
|
||||
var raw struct {
|
||||
Message any `json:"MESSAGE"`
|
||||
Priority string `json:"PRIORITY"`
|
||||
TS string `json:"__REALTIME_TIMESTAMP"`
|
||||
}
|
||||
if err := json.Unmarshal(line, &raw); err != nil {
|
||||
return LogEntry{}, false
|
||||
}
|
||||
// Records without a PRIORITY (it's often absent) default to info (6), not
|
||||
// the zero value 0 which is emerg - that would fake critical alerts.
|
||||
e := LogEntry{Priority: 6}
|
||||
e.Message, _ = raw.Message.(string)
|
||||
if p, err := strconv.Atoi(raw.Priority); err == nil {
|
||||
e.Priority = p
|
||||
}
|
||||
if us, err := strconv.ParseInt(raw.TS, 10, 64); err == nil {
|
||||
e.Time = time.UnixMicro(us).UTC().Format(time.RFC3339)
|
||||
}
|
||||
return e, true
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package services
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseJournalLine(t *testing.T) {
|
||||
line := []byte(`{"__REALTIME_TIMESTAMP":"1750406104000000","PRIORITY":"6","MESSAGE":"Started OpenSSH server daemon.","_PID":"123"}`)
|
||||
e, ok := parseJournalLine(line)
|
||||
if !ok {
|
||||
t.Fatal("expected parse to succeed")
|
||||
}
|
||||
if e.Message != "Started OpenSSH server daemon." {
|
||||
t.Errorf("message = %q", e.Message)
|
||||
}
|
||||
if e.Priority != 6 {
|
||||
t.Errorf("priority = %d", e.Priority)
|
||||
}
|
||||
if e.Time != "2025-06-20T07:55:04Z" {
|
||||
t.Errorf("time = %q", e.Time)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJournalLineBinaryMessage(t *testing.T) {
|
||||
// Binary MESSAGE is encoded as a byte array; we yield an empty message, not an error.
|
||||
line := []byte(`{"__REALTIME_TIMESTAMP":"1750406104000000","PRIORITY":"3","MESSAGE":[104,105]}`)
|
||||
e, ok := parseJournalLine(line)
|
||||
if !ok || e.Message != "" || e.Priority != 3 {
|
||||
t.Errorf("got ok=%v entry=%+v", ok, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJournalLineMissingPriority(t *testing.T) {
|
||||
// PRIORITY is often absent; it must default to info (6), not emerg (0).
|
||||
line := []byte(`{"__REALTIME_TIMESTAMP":"1750406104000000","MESSAGE":"hi"}`)
|
||||
e, ok := parseJournalLine(line)
|
||||
if !ok || e.Priority != 6 {
|
||||
t.Errorf("got ok=%v priority=%d, want priority 6", ok, e.Priority)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJournalLineGarbage(t *testing.T) {
|
||||
if _, ok := parseJournalLine([]byte("not json")); ok {
|
||||
t.Error("garbage line should not parse")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLogPath(t *testing.T) {
|
||||
allow := map[string][]string{
|
||||
"nginx.service": {"/var/log/nginx/access.log", "/var/log/nginx/error.log"},
|
||||
}
|
||||
|
||||
// Allowlisted path resolves whether the caller uses the bare or .service
|
||||
// form, regardless of which form the config key used.
|
||||
for _, unit := range []string{"nginx.service", "nginx"} {
|
||||
if p, err := resolveLogPath(allow, unit, "/var/log/nginx/error.log"); err != nil || p != "/var/log/nginx/error.log" {
|
||||
t.Errorf("allowlisted path for %q: got %q, %v", unit, p, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Everything else is rejected: empty path, non-listed path (traversal),
|
||||
// listed path but wrong unit, and unit with no allowlist at all.
|
||||
bad := []struct{ unit, path string }{
|
||||
{"nginx.service", ""},
|
||||
{"nginx.service", "/etc/shadow"},
|
||||
{"nginx.service", "/var/log/nginx/access.log/../../../etc/shadow"},
|
||||
{"sshd.service", "/var/log/nginx/error.log"},
|
||||
{"unknown.service", "/var/log/nginx/error.log"},
|
||||
}
|
||||
for _, b := range bad {
|
||||
if _, err := resolveLogPath(allow, b.unit, b.path); err == nil {
|
||||
t.Errorf("resolveLogPath(%q, %q) = nil error, want rejection", b.unit, b.path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJournalUnit(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"docker.service": "docker",
|
||||
"docker": "docker",
|
||||
"sshd.service": "sshd",
|
||||
"foo.socket": "foo.socket", // only .service is stripped
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := journalUnit(in); got != want {
|
||||
t.Errorf("journalUnit(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClampLines(t *testing.T) {
|
||||
cases := map[int]int{0: defaultLogLines, -5: defaultLogLines, 50: 50, 999999: maxLogLines}
|
||||
for in, want := range cases {
|
||||
if got := clampLines(in); got != want {
|
||||
t.Errorf("clampLines(%d) = %d, want %d", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const ModuleID = "services"
|
||||
|
||||
type Module struct {
|
||||
// logFiles is the per-unit allowlist of readable log files (from config),
|
||||
// consulted by the file log source.
|
||||
logFiles map[string][]string
|
||||
}
|
||||
|
||||
func New(logFiles map[string][]string) *Module { return &Module{logFiles: logFiles} }
|
||||
|
||||
func (m *Module) ID() string { return ModuleID }
|
||||
func (m *Module) Name() string { return "Services" }
|
||||
|
||||
// Permissions: read to list and inspect units; write to control them
|
||||
// (start/stop/restart/enable/disable).
|
||||
func (m *Module) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Read, rbac.Write}
|
||||
}
|
||||
|
||||
func (m *Module) Register(api huma.API) {
|
||||
registerServices(api)
|
||||
registerLogs(api, m.logFiles)
|
||||
}
|
||||
|
||||
func op(permission string) map[string]any {
|
||||
return map[string]any{
|
||||
"module": ModuleID,
|
||||
"permission": permission,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const tagServices = "Services"
|
||||
|
||||
var (
|
||||
readErrors = []int{401, 403, 500}
|
||||
writeErrors = []int{400, 401, 403, 404, 500}
|
||||
)
|
||||
|
||||
// unitNameRe matches valid systemd unit names. Combined with a leading-dash
|
||||
// reject and the "--" separator on every systemctl call, it keeps user-supplied
|
||||
// unit names from being interpreted as options.
|
||||
var unitNameRe = regexp.MustCompile(`^[a-zA-Z0-9@._:-]+$`)
|
||||
|
||||
// ServiceUnit mirrors one entry of `systemctl list-units --type=service -o json`.
|
||||
type ServiceUnit struct {
|
||||
Unit string `json:"unit" example:"sshd.service" doc:"Unit name"`
|
||||
Load string `json:"load" example:"loaded" doc:"Load state"`
|
||||
Active string `json:"active" example:"active" doc:"High-level active state"`
|
||||
Sub string `json:"sub" example:"running" doc:"Low-level sub state"`
|
||||
Description string `json:"description" example:"OpenSSH server daemon" doc:"Unit description"`
|
||||
}
|
||||
|
||||
type ListServicesOutput struct {
|
||||
Body struct {
|
||||
Services []ServiceUnit `json:"services" doc:"All service units, active and inactive"`
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceStatusBody is the detailed status of a single unit from `systemctl show`.
|
||||
type ServiceStatusBody struct {
|
||||
Unit string `json:"unit" example:"sshd.service"`
|
||||
Description string `json:"description" example:"OpenSSH server daemon"`
|
||||
LoadState string `json:"load_state" example:"loaded" doc:"loaded / not-found / masked"`
|
||||
ActiveState string `json:"active_state" example:"active" doc:"active / inactive / failed"`
|
||||
SubState string `json:"sub_state" example:"running"`
|
||||
UnitFileState string `json:"unit_file_state" example:"enabled" doc:"enabled / disabled / static"`
|
||||
}
|
||||
|
||||
type GetServiceOutput struct{ Body ServiceStatusBody }
|
||||
|
||||
// UnitPath is the shared path parameter for per-unit operations.
|
||||
type UnitPath struct {
|
||||
Unit string `path:"unit" example:"sshd.service" doc:"systemd unit name"`
|
||||
}
|
||||
|
||||
func registerServices(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "services-list",
|
||||
Method: "GET",
|
||||
Path: "/api/services",
|
||||
Summary: "List service units",
|
||||
Description: "Returns all service units (active and inactive) via " +
|
||||
"`systemctl list-units --type=service --all`.",
|
||||
Tags: []string{tagServices},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListServicesOutput, error) {
|
||||
out, err := oscmd.Run("systemctl", "list-units", "--type=service", "--all", "-o", "json", "--no-pager")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("systemctl list-units failed", err)
|
||||
}
|
||||
var units []ServiceUnit
|
||||
if err := json.Unmarshal([]byte(out), &units); err != nil {
|
||||
return nil, huma.Error500InternalServerError("parse systemctl json failed", err)
|
||||
}
|
||||
res := &ListServicesOutput{}
|
||||
res.Body.Services = units
|
||||
return res, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "services-get",
|
||||
Method: "GET",
|
||||
Path: "/api/services/{unit}",
|
||||
Summary: "Get a service's status",
|
||||
Description: "Returns load/active/sub/unit-file state for one unit via " +
|
||||
"`systemctl show`. Returns 404 when the unit does not exist.",
|
||||
Tags: []string{tagServices},
|
||||
Metadata: op("read"),
|
||||
Errors: []int{400, 401, 403, 404, 500},
|
||||
}, func(ctx context.Context, in *UnitPath) (*GetServiceOutput, error) {
|
||||
if err := validateUnit(in.Unit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m, err := showUnit(in.Unit)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("systemctl show failed", err)
|
||||
}
|
||||
if m["LoadState"] == "not-found" {
|
||||
return nil, huma.Error404NotFound("unit not found: " + in.Unit)
|
||||
}
|
||||
out := &GetServiceOutput{Body: ServiceStatusBody{
|
||||
Unit: m["Id"],
|
||||
Description: m["Description"],
|
||||
LoadState: m["LoadState"],
|
||||
ActiveState: m["ActiveState"],
|
||||
SubState: m["SubState"],
|
||||
UnitFileState: m["UnitFileState"],
|
||||
}}
|
||||
return out, nil
|
||||
})
|
||||
|
||||
controls := []struct{ action, summary, desc string }{
|
||||
{"start", "Start a service", "Starts the unit (`systemctl start`)."},
|
||||
{"stop", "Stop a service", "Stops the unit (`systemctl stop`)."},
|
||||
{"restart", "Restart a service", "Restarts the unit (`systemctl restart`)."},
|
||||
{"enable", "Enable a service at boot", "Enables the unit (`systemctl enable`)."},
|
||||
{"disable", "Disable a service at boot", "Disables the unit (`systemctl disable`)."},
|
||||
}
|
||||
for _, c := range controls {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "services-" + c.action,
|
||||
Method: "POST",
|
||||
Path: "/api/services/{unit}/" + c.action,
|
||||
Summary: c.summary,
|
||||
Description: c.desc + " Returns 404 when the unit does not exist.",
|
||||
Tags: []string{tagServices},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *UnitPath) (*oscmd.StatusOutput, error) {
|
||||
if err := validateUnit(in.Unit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ensureExists(in.Unit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := oscmd.Run("systemctl", c.action, "--", in.Unit); err != nil {
|
||||
return nil, huma.Error500InternalServerError("systemctl "+c.action+" failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateUnit guards against empty, flag-like, or malformed unit names.
|
||||
func validateUnit(unit string) error {
|
||||
if unit == "" || strings.HasPrefix(unit, "-") || !unitNameRe.MatchString(unit) {
|
||||
return huma.Error400BadRequest("invalid unit name: " + unit)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// showUnit returns selected properties of a unit as a key=value map. systemctl
|
||||
// show exits 0 even for unknown units (LoadState=not-found), so callers must
|
||||
// check LoadState to detect non-existence.
|
||||
func showUnit(unit string) (map[string]string, error) {
|
||||
lines, err := oscmd.RunLines("systemctl", "show",
|
||||
"-p", "Id", "-p", "Description", "-p", "LoadState",
|
||||
"-p", "ActiveState", "-p", "SubState", "-p", "UnitFileState",
|
||||
"--", unit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return oscmd.ParseKV(lines), nil
|
||||
}
|
||||
|
||||
// ensureExists returns a 404 if the unit is unknown, mapping the systemctl
|
||||
// show probe to an HTTP error for the control endpoints.
|
||||
func ensureExists(unit string) error {
|
||||
m, err := showUnit(unit)
|
||||
if err != nil {
|
||||
return huma.Error500InternalServerError("systemctl show failed", err)
|
||||
}
|
||||
if m["LoadState"] == "not-found" {
|
||||
return huma.Error404NotFound("unit not found: " + unit)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestServicesHandlers(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
|
||||
// Set up allowlisted log files for testing the file source
|
||||
logFiles := map[string][]string{
|
||||
"nginx.service": {filepath.Join(t.TempDir(), "nginx-error.log")},
|
||||
}
|
||||
// Create the dummy log file
|
||||
errLogPath := logFiles["nginx.service"][0]
|
||||
if err := os.WriteFile(errLogPath, []byte("file log line 1\nfile log line 2\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m := New(logFiles)
|
||||
m.Register(api)
|
||||
|
||||
// 1. Test GET /api/services (list services)
|
||||
oscmd.SetMock("systemctl", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"list-units", "--type=service", "--all", "-o", "json", "--no-pager"}) {
|
||||
units := []ServiceUnit{
|
||||
{Unit: "sshd.service", Load: "loaded", Active: "active", Sub: "running", Description: "OpenSSH"},
|
||||
}
|
||||
data, _ := json.Marshal(units)
|
||||
return oscmd.MockCommand{Stdout: string(data) + "\n", ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"show", "-p", "Id", "-p", "Description", "-p", "LoadState", "-p", "ActiveState", "-p", "SubState", "-p", "UnitFileState", "--", "sshd.service"}) {
|
||||
showOut := "Id=sshd.service\nDescription=OpenSSH\nLoadState=loaded\nActiveState=active\nSubState=running\nUnitFileState=enabled\n"
|
||||
return oscmd.MockCommand{Stdout: showOut, ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"start", "--", "sshd.service"}) {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
defer oscmd.ClearMocks()
|
||||
|
||||
resp := api.Get("/api/services")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list services: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var listRes ListServicesOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(listRes.Body.Services) != 1 || listRes.Body.Services[0].Unit != "sshd.service" {
|
||||
t.Errorf("list services output: %+v", listRes.Body)
|
||||
}
|
||||
|
||||
// 2. Test GET /api/services/{unit} (get service status)
|
||||
resp = api.Get("/api/services/sshd.service")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get service status: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var getRes GetServiceOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &getRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if getRes.Body.Unit != "sshd.service" || getRes.Body.ActiveState != "active" {
|
||||
t.Errorf("get service output: %+v", getRes.Body)
|
||||
}
|
||||
|
||||
// 3. Test POST /api/services/{unit}/start
|
||||
resp = api.Post("/api/services/sshd.service/start", struct{}{})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("start service: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 4. Test GET /api/services/{unit}/logs (journal source)
|
||||
oscmd.SetMock("journalctl", func(args []string) oscmd.MockCommand {
|
||||
if strings.Contains(strings.Join(args, " "), "-f") {
|
||||
// Streaming mock
|
||||
lines := []string{
|
||||
`{"MESSAGE":"streaming line 1","PRIORITY":"6","__REALTIME_TIMESTAMP":"1718873704000000"}`,
|
||||
}
|
||||
return oscmd.MockCommand{Lines: lines, DelayMs: 1, ExitCode: 0}
|
||||
}
|
||||
// Regular snapshot mock
|
||||
lines := []string{
|
||||
`{"MESSAGE":"journal line 1","PRIORITY":"6","__REALTIME_TIMESTAMP":"1718873704000000"}`,
|
||||
}
|
||||
return oscmd.MockCommand{Stdout: strings.Join(lines, "\n") + "\n", ExitCode: 0}
|
||||
})
|
||||
|
||||
resp = api.Get("/api/services/sshd.service/logs")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get journal logs: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var logsRes LogsOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &logsRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(logsRes.Body.Entries) != 1 || logsRes.Body.Entries[0].Message != "journal line 1" {
|
||||
t.Errorf("journal logs output: %+v", logsRes.Body)
|
||||
}
|
||||
|
||||
// 5. Test GET /api/services/{unit}/logs (file source)
|
||||
oscmd.SetMock("tail", func(args []string) oscmd.MockCommand {
|
||||
if strings.Contains(strings.Join(args, " "), "-F") {
|
||||
// Streaming mock
|
||||
return oscmd.MockCommand{Lines: []string{"stream file line 1"}, DelayMs: 1, ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{Stdout: "file log line 1\nfile log line 2\n", ExitCode: 0}
|
||||
})
|
||||
|
||||
resp = api.Get("/api/services/nginx.service/logs?source=file&path=" + errLogPath)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get file logs: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &logsRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(logsRes.Body.Entries) != 2 || logsRes.Body.Entries[0].Message != "file log line 1" {
|
||||
t.Errorf("file logs output: %+v", logsRes.Body)
|
||||
}
|
||||
|
||||
// 6. Test GET /api/services/{unit}/logs/stream (journal stream)
|
||||
resp = api.Get("/api/services/sshd.service/logs/stream")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("stream journal logs: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
bodyStr := resp.Body.String()
|
||||
if !strings.Contains(bodyStr, "streaming line 1") {
|
||||
t.Errorf("stream journal logs missing message, got: %q", bodyStr)
|
||||
}
|
||||
|
||||
// 7. Test GET /api/services/{unit}/logs/stream (file stream)
|
||||
resp = api.Get("/api/services/nginx.service/logs/stream?source=file&path=" + errLogPath)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("stream file logs: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
bodyStr = resp.Body.String()
|
||||
if !strings.Contains(bodyStr, "stream file line 1") {
|
||||
t.Errorf("stream file logs missing message, got: %q", bodyStr)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package services
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateUnit(t *testing.T) {
|
||||
valid := []string{"sshd.service", "getty@tty1.service", "foo.bar:baz-1.service", "a_b.timer"}
|
||||
for _, u := range valid {
|
||||
if err := validateUnit(u); err != nil {
|
||||
t.Errorf("validateUnit(%q) = %v, want nil", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Empty, flag-injection, and anything with shell/path metacharacters must
|
||||
// be rejected before reaching systemctl.
|
||||
invalid := []string{"", "-rf", "--now", "a b", "foo;rm -rf /", "a/b", "naughty$()", "x|y"}
|
||||
for _, u := range invalid {
|
||||
if err := validateUnit(u); err == nil {
|
||||
t.Errorf("validateUnit(%q) = nil, want error", u)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const tagStorage = "Storage"
|
||||
|
||||
var (
|
||||
readErrors = []int{401, 403, 500}
|
||||
writeErrors = []int{400, 401, 403, 409, 500}
|
||||
)
|
||||
|
||||
// Validation regexes. Every value is also passed after "--" on the command line.
|
||||
// We forbid whitespace and shell metacharacters outright (mount/umount take no
|
||||
// such values for our purposes, and fstab fields with spaces would need octal
|
||||
// escaping we don't emit).
|
||||
var (
|
||||
// deviceRe allows a /dev path or a UUID=/LABEL= specifier.
|
||||
deviceRe = regexp.MustCompile(`^(/dev/[a-zA-Z0-9/._-]+|UUID=[a-zA-Z0-9-]+|LABEL=[a-zA-Z0-9._-]+|PARTUUID=[a-zA-Z0-9-]+)$`)
|
||||
// mountpointRe is an absolute path without traversal or whitespace.
|
||||
mountpointRe = regexp.MustCompile(`^/[a-zA-Z0-9/._-]*$`)
|
||||
// fstypeRe matches a filesystem type token (ext4, xfs, btrfs, vfat, nfs, …).
|
||||
fstypeRe = regexp.MustCompile(`^[a-z0-9]+$`)
|
||||
// optionsRe matches a comma-separated mount option list (defaults, noatime,
|
||||
// uid=1000, …).
|
||||
optionsRe = regexp.MustCompile(`^[a-zA-Z0-9=,._:+-]+$`)
|
||||
)
|
||||
|
||||
func validateMountpoint(mp string) error {
|
||||
if !mountpointRe.MatchString(mp) || hasDotDot(mp) {
|
||||
return huma.Error400BadRequest("invalid mountpoint: " + mp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasDotDot rejects a ".." path segment so a mountpoint can't escape upward.
|
||||
func hasDotDot(p string) bool {
|
||||
for _, seg := range strings.Split(p, "/") {
|
||||
if seg == ".." {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const ModuleID = "storage"
|
||||
|
||||
type Module struct{}
|
||||
|
||||
func New() *Module { return &Module{} }
|
||||
|
||||
func (m *Module) ID() string { return ModuleID }
|
||||
func (m *Module) Name() string { return "Storage" }
|
||||
|
||||
// Permissions: read to list mounts and fstab; write to mount/unmount and edit
|
||||
// fstab entries.
|
||||
func (m *Module) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Read, rbac.Write}
|
||||
}
|
||||
|
||||
func (m *Module) Register(api huma.API) {
|
||||
registerStorage(api)
|
||||
}
|
||||
|
||||
func op(permission string) map[string]any {
|
||||
return map[string]any{"module": ModuleID, "permission": permission}
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/mounts"
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// fstabFile is a var so tests can point it at a fixture.
|
||||
var fstabFile = "/etc/fstab"
|
||||
|
||||
// FstabEntry is one /etc/fstab line. Dump and Pass are the last two numeric
|
||||
// fields (fs_freq and fs_passno).
|
||||
type FstabEntry struct {
|
||||
Device string `json:"device" example:"UUID=1234-5678"`
|
||||
Mountpoint string `json:"mountpoint" example:"/mnt/data"`
|
||||
FSType string `json:"fstype" example:"ext4"`
|
||||
Options string `json:"options" example:"defaults"`
|
||||
Dump int `json:"dump" example:"0"`
|
||||
Pass int `json:"pass" example:"2" doc:"fsck order (0 = skip)"`
|
||||
}
|
||||
|
||||
type ListMountsOutput struct {
|
||||
Body struct {
|
||||
Mounts []mounts.Mount `json:"mounts"`
|
||||
}
|
||||
}
|
||||
|
||||
type ListFstabOutput struct {
|
||||
Body struct {
|
||||
Entries []FstabEntry `json:"entries"`
|
||||
}
|
||||
}
|
||||
|
||||
type MountInput struct {
|
||||
Body struct {
|
||||
Device string `json:"device" example:"/dev/sdb1" doc:"Block device or UUID=/LABEL= specifier"`
|
||||
Mountpoint string `json:"mountpoint" example:"/mnt/data" doc:"Absolute mount path"`
|
||||
FSType string `json:"fstype" example:"ext4"`
|
||||
Options string `json:"options,omitempty" example:"defaults" doc:"Mount options (default: defaults)"`
|
||||
Dump int `json:"dump,omitempty"`
|
||||
Pass int `json:"pass,omitempty" doc:"fsck order (default 2 for real filesystems, 0 to skip)"`
|
||||
}
|
||||
}
|
||||
|
||||
type UnmountInput struct {
|
||||
Mountpoint string `query:"mountpoint" example:"/mnt/data" doc:"Mountpoint to unmount and remove from fstab"`
|
||||
}
|
||||
|
||||
func registerStorage(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "storage-list-mounts",
|
||||
Method: "GET",
|
||||
Path: "/api/storage/mounts",
|
||||
Summary: "List active mounts",
|
||||
Description: "Returns the kernel mount table (/proc/mounts).",
|
||||
Tags: []string{tagStorage},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListMountsOutput, error) {
|
||||
entries, err := mounts.Proc()
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read mounts failed", err)
|
||||
}
|
||||
res := &ListMountsOutput{}
|
||||
res.Body.Mounts = entries
|
||||
return res, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "storage-list-fstab",
|
||||
Method: "GET",
|
||||
Path: "/api/storage/fstab",
|
||||
Summary: "List /etc/fstab entries",
|
||||
Description: "Returns the persistent mount definitions from /etc/fstab.",
|
||||
Tags: []string{tagStorage},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListFstabOutput, error) {
|
||||
data, err := os.ReadFile(fstabFile)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read fstab failed", err)
|
||||
}
|
||||
res := &ListFstabOutput{}
|
||||
res.Body.Entries = parseFstab(string(data))
|
||||
return res, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "storage-add-mount",
|
||||
Method: "POST",
|
||||
Path: "/api/storage/mounts",
|
||||
Summary: "Add and mount a filesystem",
|
||||
Description: "Appends an /etc/fstab entry and mounts it. If the mount fails the " +
|
||||
"fstab entry is rolled back, so a bad request leaves the system unchanged.",
|
||||
Tags: []string{tagStorage},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *MountInput) (*oscmd.StatusOutput, error) {
|
||||
e := FstabEntry{
|
||||
Device: in.Body.Device,
|
||||
Mountpoint: in.Body.Mountpoint,
|
||||
FSType: in.Body.FSType,
|
||||
Options: in.Body.Options,
|
||||
Dump: in.Body.Dump,
|
||||
Pass: in.Body.Pass,
|
||||
}
|
||||
if e.Options == "" {
|
||||
e.Options = "defaults"
|
||||
}
|
||||
if err := validateEntry(e); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := readFstab()
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read fstab failed", err)
|
||||
}
|
||||
if findEntry(existing, e.Mountpoint) != nil {
|
||||
return nil, huma.Error409Conflict("an fstab entry already exists for " + e.Mountpoint)
|
||||
}
|
||||
|
||||
if err := appendFstabLine(e); err != nil {
|
||||
return nil, huma.Error500InternalServerError("write fstab failed", err)
|
||||
}
|
||||
// mount reads the fstab line we just wrote. On failure, roll the line back
|
||||
// so a bad device/options doesn't linger in fstab and break the next boot.
|
||||
if _, err := oscmd.RunContext(ctx, "mount", "--", e.Mountpoint); err != nil {
|
||||
if _, werr := removeFstabLines(e.Mountpoint); werr != nil {
|
||||
return nil, huma.Error500InternalServerError("mount failed and fstab rollback failed", werr)
|
||||
}
|
||||
return nil, huma.Error400BadRequest("mount failed: " + err.Error())
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "storage-remove-mount",
|
||||
Method: "DELETE",
|
||||
Path: "/api/storage/mounts",
|
||||
Summary: "Unmount and remove a filesystem",
|
||||
Description: "Unmounts the filesystem (if mounted) and removes its /etc/fstab entry. " +
|
||||
"The mountpoint is passed as a query parameter since it contains slashes.",
|
||||
Tags: []string{tagStorage},
|
||||
Metadata: op("write"),
|
||||
Errors: []int{400, 401, 403, 404, 409, 500},
|
||||
}, func(ctx context.Context, in *UnmountInput) (*oscmd.StatusOutput, error) {
|
||||
if err := validateMountpoint(in.Mountpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entries, err := readFstab()
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read fstab failed", err)
|
||||
}
|
||||
inFstab := findEntry(entries, in.Mountpoint) != nil
|
||||
|
||||
mounted, err := isMounted(in.Mountpoint)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read mounts failed", err)
|
||||
}
|
||||
if !inFstab && !mounted {
|
||||
return nil, huma.Error404NotFound("no mount or fstab entry for " + in.Mountpoint)
|
||||
}
|
||||
|
||||
if mounted {
|
||||
if _, err := oscmd.RunContext(ctx, "umount", "--", in.Mountpoint); err != nil {
|
||||
return nil, huma.Error409Conflict("unmount failed (in use?): " + err.Error())
|
||||
}
|
||||
}
|
||||
if inFstab {
|
||||
if _, err := removeFstabLines(in.Mountpoint); err != nil {
|
||||
return nil, huma.Error500InternalServerError("write fstab failed", err)
|
||||
}
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
}
|
||||
|
||||
// --- fstab helpers -----------------------------------------------------------
|
||||
|
||||
func readFstab() ([]FstabEntry, error) {
|
||||
data, err := os.ReadFile(fstabFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseFstab(string(data)), nil
|
||||
}
|
||||
|
||||
// parseFstab parses /etc/fstab, skipping comments and blank lines.
|
||||
func parseFstab(data string) []FstabEntry {
|
||||
entries := []FstabEntry{}
|
||||
for line := range strings.SplitSeq(data, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
f := strings.Fields(trimmed)
|
||||
if len(f) < 4 { // device mountpoint fstype options [dump] [pass]
|
||||
continue
|
||||
}
|
||||
e := FstabEntry{
|
||||
Device: mounts.Unescape(f[0]),
|
||||
Mountpoint: mounts.Unescape(f[1]),
|
||||
FSType: f[2],
|
||||
Options: f[3],
|
||||
}
|
||||
if len(f) >= 5 {
|
||||
e.Dump, _ = strconv.Atoi(f[4])
|
||||
}
|
||||
if len(f) >= 6 {
|
||||
e.Pass, _ = strconv.Atoi(f[5])
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
// fstab edits are surgical (append one line / drop matching lines) rather than a
|
||||
// whole-file rewrite, so an admin's existing entries, comments and formatting in
|
||||
// this boot-critical file are preserved — same approach as /etc/hosts.
|
||||
|
||||
func renderFstabLine(e FstabEntry) string {
|
||||
return fmt.Sprintf("%s\t%s\t%s\t%s\t%d\t%d", e.Device, e.Mountpoint, e.FSType, e.Options, e.Dump, e.Pass)
|
||||
}
|
||||
|
||||
// appendFstabLine adds one entry, leaving every existing line untouched.
|
||||
func appendFstabLine(e FstabEntry) error {
|
||||
data, err := os.ReadFile(fstabFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
content := string(data)
|
||||
if content != "" && !strings.HasSuffix(content, "\n") {
|
||||
content += "\n"
|
||||
}
|
||||
content += renderFstabLine(e) + "\n"
|
||||
return os.WriteFile(fstabFile, []byte(content), 0644)
|
||||
}
|
||||
|
||||
// removeFstabLines drops every line mapping mountpoint, preserving comments and
|
||||
// other entries. Reports whether anything was removed.
|
||||
func removeFstabLines(mountpoint string) (bool, error) {
|
||||
data, err := os.ReadFile(fstabFile)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
lines := strings.Split(string(data), "\n")
|
||||
kept := make([]string, 0, len(lines))
|
||||
removed := false
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed != "" && !strings.HasPrefix(trimmed, "#") {
|
||||
if f := strings.Fields(trimmed); len(f) >= 2 && mounts.Unescape(f[1]) == mountpoint {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
kept = append(kept, line)
|
||||
}
|
||||
if !removed {
|
||||
return false, nil
|
||||
}
|
||||
return true, os.WriteFile(fstabFile, []byte(strings.Join(kept, "\n")), 0644)
|
||||
}
|
||||
|
||||
func findEntry(entries []FstabEntry, mountpoint string) *FstabEntry {
|
||||
for i := range entries {
|
||||
if entries[i].Mountpoint == mountpoint {
|
||||
return &entries[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isMounted reports whether mountpoint currently appears in /proc/mounts.
|
||||
func isMounted(mountpoint string) (bool, error) {
|
||||
active, err := mounts.Proc()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, e := range active {
|
||||
if e.Mountpoint == mountpoint {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func validateEntry(e FstabEntry) error {
|
||||
if !deviceRe.MatchString(e.Device) {
|
||||
return huma.Error400BadRequest("invalid device: " + e.Device)
|
||||
}
|
||||
if err := validateMountpoint(e.Mountpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
if !fstypeRe.MatchString(e.FSType) {
|
||||
return huma.Error400BadRequest("invalid fstype: " + e.FSType)
|
||||
}
|
||||
if !optionsRe.MatchString(e.Options) {
|
||||
return huma.Error400BadRequest("invalid options: " + e.Options)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestStorageHandlers(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "fstab")
|
||||
if err := os.WriteFile(path, []byte("# seed\nUUID=root / ext4 defaults 0 1\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
old := fstabFile
|
||||
fstabFile = path
|
||||
defer func() { fstabFile = old }()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
New().Register(api)
|
||||
|
||||
mountOK := true
|
||||
oscmd.SetMock("mount", func([]string) oscmd.MockCommand {
|
||||
if mountOK {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{Stderr: "mount: wrong fs type", ExitCode: 32}
|
||||
})
|
||||
oscmd.SetMock("umount", func([]string) oscmd.MockCommand { return oscmd.MockCommand{ExitCode: 0} })
|
||||
defer oscmd.ClearMocks()
|
||||
|
||||
body := map[string]any{"device": "/dev/sdb1", "mountpoint": "/mnt/data", "fstype": "ext4"}
|
||||
|
||||
// Add + mount succeeds.
|
||||
if resp := api.Post("/api/storage/mounts", body); resp.Code != http.StatusOK {
|
||||
t.Fatalf("add mount: got %d, body=%s", resp.Code, resp.Body.String())
|
||||
}
|
||||
if findEntry(mustReadFstab(t), "/mnt/data") == nil {
|
||||
t.Fatal("entry not written to fstab")
|
||||
}
|
||||
|
||||
// Duplicate mountpoint → 409.
|
||||
if resp := api.Post("/api/storage/mounts", body); resp.Code != http.StatusConflict {
|
||||
t.Errorf("duplicate: got %d, want 409", resp.Code)
|
||||
}
|
||||
|
||||
// mount fails → 400 and the fstab line is rolled back.
|
||||
mountOK = false
|
||||
bad := map[string]any{"device": "/dev/sdc1", "mountpoint": "/mnt/bad", "fstype": "ext4"}
|
||||
if resp := api.Post("/api/storage/mounts", bad); resp.Code != http.StatusBadRequest {
|
||||
t.Errorf("failed mount: got %d, want 400", resp.Code)
|
||||
}
|
||||
if findEntry(mustReadFstab(t), "/mnt/bad") != nil {
|
||||
t.Error("fstab entry not rolled back after mount failure")
|
||||
}
|
||||
|
||||
// Invalid device is rejected before touching the system.
|
||||
if resp := api.Post("/api/storage/mounts", map[string]any{"device": "/dev/x; rm", "mountpoint": "/mnt/x", "fstype": "ext4"}); resp.Code != http.StatusBadRequest {
|
||||
t.Errorf("invalid device: got %d, want 400", resp.Code)
|
||||
}
|
||||
|
||||
// Delete the good entry (not actually mounted, so no umount needed).
|
||||
if resp := api.Delete("/api/storage/mounts?mountpoint=/mnt/data"); resp.Code != http.StatusOK {
|
||||
t.Errorf("delete: got %d, body=%s", resp.Code, resp.Body.String())
|
||||
}
|
||||
data, _ := os.ReadFile(path)
|
||||
if strings.Contains(string(data), "/mnt/data") || !strings.Contains(string(data), "# seed") {
|
||||
t.Errorf("delete result wrong:\n%s", data)
|
||||
}
|
||||
|
||||
// Deleting a nonexistent mountpoint → 404.
|
||||
if resp := api.Delete("/api/storage/mounts?mountpoint=/mnt/none"); resp.Code != http.StatusNotFound {
|
||||
t.Errorf("delete missing: got %d, want 404", resp.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestParseFstab(t *testing.T) {
|
||||
data := `# /etc/fstab
|
||||
UUID=1111-2222 / ext4 defaults 0 1
|
||||
|
||||
/dev/sdb1 /mnt/data ext4 rw,noatime 0 2
|
||||
LABEL=swap none swap sw
|
||||
`
|
||||
got := parseFstab(data)
|
||||
want := []FstabEntry{
|
||||
{Device: "UUID=1111-2222", Mountpoint: "/", FSType: "ext4", Options: "defaults", Dump: 0, Pass: 1},
|
||||
{Device: "/dev/sdb1", Mountpoint: "/mnt/data", FSType: "ext4", Options: "rw,noatime", Dump: 0, Pass: 2},
|
||||
{Device: "LABEL=swap", Mountpoint: "none", FSType: "swap", Options: "sw", Dump: 0, Pass: 0},
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("parseFstab:\n got %+v\nwant %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFstabSurgicalEdits(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "fstab")
|
||||
old := fstabFile
|
||||
fstabFile = path
|
||||
defer func() { fstabFile = old }()
|
||||
|
||||
// Seed with a comment + an existing entry that must survive edits.
|
||||
seed := "# keep me\nUUID=root / ext4 defaults 0 1\n"
|
||||
if err := os.WriteFile(path, []byte(seed), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Append a new entry.
|
||||
if err := appendFstabLine(FstabEntry{Device: "/dev/sdb1", Mountpoint: "/mnt/data", FSType: "xfs", Options: "noatime", Pass: 2}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
entries, err := readFstab()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(entries) != 2 || findEntry(entries, "/mnt/data") == nil {
|
||||
t.Fatalf("append failed: %+v", entries)
|
||||
}
|
||||
|
||||
// Remove it again; the comment and original entry must remain.
|
||||
removed, err := removeFstabLines("/mnt/data")
|
||||
if err != nil || !removed {
|
||||
t.Fatalf("remove failed: removed=%v err=%v", removed, err)
|
||||
}
|
||||
data, _ := os.ReadFile(path)
|
||||
if !strings.Contains(string(data), "# keep me") || !strings.Contains(string(data), "UUID=root") {
|
||||
t.Errorf("surgical edit clobbered other lines:\n%s", data)
|
||||
}
|
||||
if findEntry(mustReadFstab(t), "/mnt/data") != nil {
|
||||
t.Error("entry still present after remove")
|
||||
}
|
||||
|
||||
// Removing a missing mountpoint reports not-removed.
|
||||
if removed, _ := removeFstabLines("/nope"); removed {
|
||||
t.Error("expected removed=false for missing mountpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func mustReadFstab(t *testing.T) []FstabEntry {
|
||||
t.Helper()
|
||||
e, err := readFstab()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func TestValidateEntry(t *testing.T) {
|
||||
ok := FstabEntry{Device: "/dev/sdb1", Mountpoint: "/mnt/data", FSType: "ext4", Options: "defaults"}
|
||||
if err := validateEntry(ok); err != nil {
|
||||
t.Errorf("valid entry rejected: %v", err)
|
||||
}
|
||||
if err := validateEntry(FstabEntry{Device: "UUID=ab-cd", Mountpoint: "/mnt/x", FSType: "xfs", Options: "rw,noatime"}); err != nil {
|
||||
t.Errorf("UUID entry rejected: %v", err)
|
||||
}
|
||||
bad := []FstabEntry{
|
||||
{Device: "/dev/sdb1; rm -rf /", Mountpoint: "/mnt/x", FSType: "ext4", Options: "defaults"}, // shell metachars
|
||||
{Device: "/dev/sdb1", Mountpoint: "../etc", FSType: "ext4", Options: "defaults"}, // not absolute
|
||||
{Device: "/dev/sdb1", Mountpoint: "/mnt/../../etc", FSType: "ext4", Options: "defaults"}, // traversal
|
||||
{Device: "/dev/sdb1", Mountpoint: "/mnt/x", FSType: "ext4!", Options: "defaults"}, // bad fstype
|
||||
{Device: "/dev/sdb1", Mountpoint: "/mnt/x", FSType: "ext4", Options: "defaults; reboot"}, // bad options
|
||||
}
|
||||
for i, e := range bad {
|
||||
if err := validateEntry(e); err == nil {
|
||||
t.Errorf("bad entry %d accepted: %+v", i, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package system
|
||||
|
||||
// tagSystem groups every system-module operation (info, hostname, time/date,
|
||||
// locale, power) under one OpenAPI tag, so tags map 1:1 to modules.
|
||||
const tagSystem = "System"
|
||||
|
||||
var (
|
||||
readErrors = []int{401, 403, 500}
|
||||
writeErrors = []int{400, 401, 403, 500}
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
type HostnameBody struct {
|
||||
Hostname string `json:"hostname" example:"server01" doc:"System hostname"`
|
||||
}
|
||||
|
||||
type GetHostnameOutput struct{ Body HostnameBody }
|
||||
type SetHostnameInput struct{ Body HostnameBody }
|
||||
|
||||
func registerHostname(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-get-hostname",
|
||||
Method: "GET",
|
||||
Path: "/api/system/hostname",
|
||||
Summary: "Get system hostname",
|
||||
Description: "Returns the current hostname as reported by hostnamectl.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*GetHostnameOutput, error) {
|
||||
name, err := oscmd.Run("hostnamectl", "hostname")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("hostnamectl failed", err)
|
||||
}
|
||||
return &GetHostnameOutput{Body: HostnameBody{Hostname: name}}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-set-hostname",
|
||||
Method: "POST",
|
||||
Path: "/api/system/hostname",
|
||||
Summary: "Set system hostname",
|
||||
Description: "Sets the static hostname via hostnamectl, which owns " +
|
||||
"/etc/hostname and manages the static/pretty/transient names.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *SetHostnameInput) (*oscmd.StatusOutput, error) {
|
||||
name := strings.TrimSpace(in.Body.Hostname)
|
||||
if name == "" {
|
||||
return nil, huma.Error400BadRequest("empty hostname")
|
||||
}
|
||||
if _, err := oscmd.Run("hostnamectl", "set-hostname", name); err != nil {
|
||||
return nil, huma.Error500InternalServerError("hostnamectl failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,544 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"nadir/internal/mounts"
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// SystemInfoBody is the dashboard overview: OS identity plus live CPU, memory,
|
||||
// disk, load, network, and temperature readings. Every section is best-effort —
|
||||
// a source that's unavailable (e.g. no thermal zones in a VM) yields a zero
|
||||
// value or empty list rather than failing the whole call.
|
||||
type SystemInfoBody struct {
|
||||
OS OSInfo `json:"os" doc:"OS and kernel identity"`
|
||||
CPU CPUInfo `json:"cpu" doc:"Processor model and core count"`
|
||||
Memory MemoryInfo `json:"memory" doc:"RAM and swap usage in bytes"`
|
||||
Load LoadInfo `json:"load" doc:"Load averages (1/5/15 min)"`
|
||||
UptimeSec int64 `json:"uptime_seconds" example:"12490" doc:"Seconds since boot"`
|
||||
BootTime string `json:"boot_time" example:"2026-06-19T12:08:00Z" doc:"Boot time (RFC3339, UTC)"`
|
||||
Disks []DiskInfo `json:"disks" doc:"Mounted block-device filesystems"`
|
||||
NetworkInterfaces []NetInterface `json:"network_interfaces" doc:"Network interfaces and their addresses"`
|
||||
Temperatures []Temperature `json:"temperatures" doc:"Thermal sensor readings in Celsius"`
|
||||
}
|
||||
|
||||
type OSInfo struct {
|
||||
PrettyName string `json:"pretty_name" example:"Fedora Linux 44 (Workstation Edition)" doc:"Distro name from /etc/os-release PRETTY_NAME"`
|
||||
Kernel string `json:"kernel" example:"7.0.12-201.fc44.x86_64" doc:"Running kernel release (uname -r)"`
|
||||
Architecture string `json:"architecture" example:"x86_64" doc:"Machine hardware architecture (uname -m)"`
|
||||
Hostname string `json:"hostname" example:"server01" doc:"System hostname"`
|
||||
}
|
||||
|
||||
type CPUInfo struct {
|
||||
Model string `json:"model" example:"AMD Ryzen 7 7840U" doc:"CPU model name"`
|
||||
LogicalCPUs int `json:"logical_cpus" example:"16" doc:"Number of logical CPUs (cores × threads)"`
|
||||
MinMHz int `json:"min_mhz" example:"400" doc:"Lowest frequency the scaling governor can select"`
|
||||
MaxMHz int `json:"max_mhz" example:"5137" doc:"Highest frequency (boost ceiling)"`
|
||||
CurrentMHz int `json:"current_mhz" example:"3157" doc:"Peak current clock across all cores (instantaneous snapshot; 0 if cpufreq unavailable)"`
|
||||
}
|
||||
|
||||
type MemoryInfo struct {
|
||||
TotalBytes uint64 `json:"total_bytes" example:"16384000000"`
|
||||
AvailableBytes uint64 `json:"available_bytes" example:"8192000000" doc:"Memory available for new allocations without swapping"`
|
||||
UsedBytes uint64 `json:"used_bytes" example:"8192000000" doc:"total - available"`
|
||||
SwapTotalBytes uint64 `json:"swap_total_bytes" example:"8589934592"`
|
||||
SwapFreeBytes uint64 `json:"swap_free_bytes" example:"8589934592"`
|
||||
}
|
||||
|
||||
type LoadInfo struct {
|
||||
Load1 float64 `json:"load1" example:"0.42"`
|
||||
Load5 float64 `json:"load5" example:"0.55"`
|
||||
Load15 float64 `json:"load15" example:"0.61"`
|
||||
CPUUsage []CoreUsage `json:"cpu_usage" doc:"Per-core CPU usage percentage (sampled over ~1 s); empty until the first sample completes"`
|
||||
}
|
||||
|
||||
// CoreUsage holds the usage percentage for a single logical core, computed as
|
||||
// the delta of non-idle ticks over total ticks between two /proc/stat reads.
|
||||
type CoreUsage struct {
|
||||
Core int `json:"core" example:"0" doc:"Logical core index"`
|
||||
UsagePct float64 `json:"usage_pct" example:"23.4" doc:"Usage percentage (0–100)"`
|
||||
}
|
||||
|
||||
type DiskInfo struct {
|
||||
Mountpoint string `json:"mountpoint" example:"/"`
|
||||
Filesystem string `json:"filesystem" example:"/dev/nvme0n1p2" doc:"Backing device"`
|
||||
FSType string `json:"fstype" example:"btrfs"`
|
||||
TotalBytes uint64 `json:"total_bytes" example:"512000000000"`
|
||||
FreeBytes uint64 `json:"free_bytes" example:"256000000000" doc:"Space available to unprivileged users"`
|
||||
UsedBytes uint64 `json:"used_bytes" example:"256000000000"`
|
||||
}
|
||||
|
||||
type NetInterface struct {
|
||||
Name string `json:"name" example:"eth0"`
|
||||
MAC string `json:"mac" example:"aa:bb:cc:dd:ee:ff"`
|
||||
Up bool `json:"up" doc:"Interface is administratively up"`
|
||||
Addresses []string `json:"addresses" doc:"Assigned addresses in CIDR notation"`
|
||||
}
|
||||
|
||||
type Temperature struct {
|
||||
Chip string `json:"chip" example:"k10temp" doc:"hwmon chip name; identifies the source (k10temp/coretemp=CPU, amdgpu/nvidia=GPU, nvme=disk)"`
|
||||
Label string `json:"label" example:"Tctl" doc:"Per-sensor label, or the chip name when the sensor is unlabelled"`
|
||||
Celsius float64 `json:"celsius" example:"47.5"`
|
||||
}
|
||||
|
||||
type GetInfoOutput struct{ Body SystemInfoBody }
|
||||
|
||||
func registerInfo(api huma.API) {
|
||||
startCPUSampler()
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-get-info",
|
||||
Method: "GET",
|
||||
Path: "/api/system/info",
|
||||
Summary: "Get system information",
|
||||
Description: "Returns an overview for a dashboard: OS/kernel identity, CPU, " +
|
||||
"memory and swap, mounted disks, load averages, uptime, network " +
|
||||
"interfaces, and temperatures. All values come from cheap local reads " +
|
||||
"(/proc, /sys, syscalls) with no D-Bus dependency; each section is " +
|
||||
"best-effort.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*GetInfoOutput, error) {
|
||||
uptime, boot := uptimeAndBoot()
|
||||
return &GetInfoOutput{Body: SystemInfoBody{
|
||||
OS: osInfo(),
|
||||
CPU: cpuInfo(),
|
||||
Memory: memInfo(),
|
||||
Load: loadInfo(),
|
||||
UptimeSec: uptime,
|
||||
BootTime: boot.Format(time.RFC3339),
|
||||
Disks: diskInfo(),
|
||||
NetworkInterfaces: netInfo(),
|
||||
Temperatures: tempInfo(),
|
||||
}}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func osInfo() OSInfo {
|
||||
host, _ := os.Hostname()
|
||||
return OSInfo{
|
||||
PrettyName: osReleasePretty(),
|
||||
Kernel: firstLine(oscmd.Run("uname", "-r")),
|
||||
Architecture: firstLine(oscmd.Run("uname", "-m")),
|
||||
Hostname: host,
|
||||
}
|
||||
}
|
||||
|
||||
// firstLine discards a command error and returns its (already trimmed) output,
|
||||
// used where a missing value is acceptable.
|
||||
func firstLine(out string, _ error) string { return out }
|
||||
|
||||
func osReleasePretty() string {
|
||||
data, err := os.ReadFile("/etc/os-release")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for line := range strings.SplitSeq(string(data), "\n") {
|
||||
if v, ok := strings.CutPrefix(line, "PRETTY_NAME="); ok {
|
||||
return strings.Trim(v, `"`)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func cpuInfo() CPUInfo {
|
||||
data, _ := os.ReadFile("/proc/cpuinfo")
|
||||
c := CPUInfo{Model: cpuModel(string(data)), LogicalCPUs: runtime.NumCPU()}
|
||||
c.MinMHz, c.MaxMHz, c.CurrentMHz = cpuFreqMHz("/sys/devices/system/cpu")
|
||||
// ponytail: cpufreq sysfs is absent on many VMs and stock Ubuntu server
|
||||
// kernels; fall back to /proc/cpuinfo "cpu MHz" so CurrentMHz isn't 0.
|
||||
if c.CurrentMHz == 0 {
|
||||
c.CurrentMHz = cpuinfoMaxMHz(string(data))
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// cpuinfoMaxMHz returns the highest "cpu MHz" value across all cores in
|
||||
// /proc/cpuinfo, rounded to an int. Returns 0 when no such line exists.
|
||||
func cpuinfoMaxMHz(cpuinfo string) int {
|
||||
var max float64
|
||||
for line := range strings.SplitSeq(cpuinfo, "\n") {
|
||||
k, v, ok := strings.Cut(line, ":")
|
||||
if !ok || strings.TrimSpace(k) != "cpu MHz" {
|
||||
continue
|
||||
}
|
||||
if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil && f > max {
|
||||
max = f
|
||||
}
|
||||
}
|
||||
return int(math.Round(max))
|
||||
}
|
||||
|
||||
// cpuFreqMHz reads cpufreq sysfs: min/max are stable hardware limits (from
|
||||
// cpu0); current is the highest scaling_cur_freq across all cores — the "is it
|
||||
// boosting" figure. Values are kHz in sysfs. Returns zeros when cpufreq is
|
||||
// absent (e.g. some VMs).
|
||||
func cpuFreqMHz(root string) (min, max, cur int) {
|
||||
min = readKHzAsMHz(filepath.Join(root, "cpu0/cpufreq/cpuinfo_min_freq"))
|
||||
max = readKHzAsMHz(filepath.Join(root, "cpu0/cpufreq/cpuinfo_max_freq"))
|
||||
cores, _ := filepath.Glob(filepath.Join(root, "cpu[0-9]*/cpufreq/scaling_cur_freq"))
|
||||
for _, f := range cores {
|
||||
if v := readKHzAsMHz(f); v > cur {
|
||||
cur = v
|
||||
}
|
||||
}
|
||||
return min, max, cur
|
||||
}
|
||||
|
||||
func readKHzAsMHz(path string) int {
|
||||
khz, err := strconv.Atoi(readTrim(path))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return khz / 1000
|
||||
}
|
||||
|
||||
// cpuModel extracts the processor model from /proc/cpuinfo. x86 uses "model
|
||||
// name"; many ARM boards use "Model" instead, so fall back to it.
|
||||
func cpuModel(cpuinfo string) string {
|
||||
var fallback string
|
||||
for line := range strings.SplitSeq(cpuinfo, "\n") {
|
||||
k, v, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(k) {
|
||||
case "model name":
|
||||
return strings.TrimSpace(v)
|
||||
case "Model":
|
||||
fallback = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func memInfo() MemoryInfo {
|
||||
data, _ := os.ReadFile("/proc/meminfo")
|
||||
return parseMeminfo(data)
|
||||
}
|
||||
|
||||
// parseMeminfo reads the kB values in /proc/meminfo and converts them to bytes.
|
||||
func parseMeminfo(data []byte) MemoryInfo {
|
||||
kv := map[string]uint64{}
|
||||
for line := range strings.SplitSeq(string(data), "\n") {
|
||||
k, v, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(v) // e.g. "16384000 kB"
|
||||
if len(fields) == 0 {
|
||||
continue
|
||||
}
|
||||
if n, err := strconv.ParseUint(fields[0], 10, 64); err == nil {
|
||||
kv[k] = n * 1024 // values are in kB
|
||||
}
|
||||
}
|
||||
return MemoryInfo{
|
||||
TotalBytes: kv["MemTotal"],
|
||||
AvailableBytes: kv["MemAvailable"],
|
||||
UsedBytes: kv["MemTotal"] - kv["MemAvailable"],
|
||||
SwapTotalBytes: kv["SwapTotal"],
|
||||
SwapFreeBytes: kv["SwapFree"],
|
||||
}
|
||||
}
|
||||
|
||||
func loadInfo() LoadInfo {
|
||||
data, _ := os.ReadFile("/proc/loadavg")
|
||||
l := parseLoadavg(string(data))
|
||||
l.CPUUsage = cachedCPUUsage()
|
||||
return l
|
||||
}
|
||||
|
||||
// parseLoadavg reads the three load averages from /proc/loadavg.
|
||||
func parseLoadavg(loadavg string) LoadInfo {
|
||||
f := strings.Fields(loadavg)
|
||||
if len(f) < 3 {
|
||||
return LoadInfo{}
|
||||
}
|
||||
at := func(i int) float64 { v, _ := strconv.ParseFloat(f[i], 64); return v }
|
||||
return LoadInfo{Load1: at(0), Load5: at(1), Load15: at(2)}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-core CPU usage sampler
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// /proc/stat exposes cumulative jiffies per core:
|
||||
//
|
||||
// cpuN user nice system idle iowait irq softirq steal guest guest_nice
|
||||
//
|
||||
// We sample every second, compute the delta, and derive:
|
||||
//
|
||||
// usage% = (totalΔ − idleΔ) / totalΔ × 100
|
||||
//
|
||||
// The result is cached behind a RWMutex so the HTTP handler never blocks.
|
||||
|
||||
var (
|
||||
usageMu sync.RWMutex
|
||||
usageCache []CoreUsage
|
||||
)
|
||||
|
||||
func cachedCPUUsage() []CoreUsage {
|
||||
usageMu.RLock()
|
||||
defer usageMu.RUnlock()
|
||||
// Return a copy so callers can't mutate the cache.
|
||||
if usageCache == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]CoreUsage, len(usageCache))
|
||||
copy(out, usageCache)
|
||||
return out
|
||||
}
|
||||
|
||||
// startCPUSampler launches a goroutine that samples /proc/stat once per second
|
||||
// for the lifetime of the process. Safe to call multiple times (only the first
|
||||
// call starts the goroutine).
|
||||
var samplerOnce sync.Once
|
||||
|
||||
func startCPUSampler() {
|
||||
samplerOnce.Do(func() {
|
||||
go cpuSamplerLoop("/proc/stat", 1*time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
func cpuSamplerLoop(statPath string, interval time.Duration) {
|
||||
prev := readProcStat(statPath)
|
||||
for {
|
||||
time.Sleep(interval)
|
||||
cur := readProcStat(statPath)
|
||||
usage := computeUsage(prev, cur)
|
||||
usageMu.Lock()
|
||||
usageCache = usage
|
||||
usageMu.Unlock()
|
||||
prev = cur
|
||||
}
|
||||
}
|
||||
|
||||
// cpuCoreTicks holds the cumulative jiffies for one "cpuN" line.
|
||||
type cpuCoreTicks struct {
|
||||
core int
|
||||
total uint64
|
||||
idle uint64
|
||||
}
|
||||
|
||||
// readProcStat reads /proc/stat and returns per-core tick totals. The
|
||||
// aggregate "cpu" line (no digit suffix) is skipped.
|
||||
func readProcStat(path string) []cpuCoreTicks {
|
||||
data, _ := os.ReadFile(path)
|
||||
var cores []cpuCoreTicks
|
||||
for line := range strings.SplitSeq(string(data), "\n") {
|
||||
if !strings.HasPrefix(line, "cpu") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 5 {
|
||||
continue
|
||||
}
|
||||
// Skip the aggregate "cpu" line; we only want "cpu0", "cpu1", …
|
||||
name := fields[0]
|
||||
if name == "cpu" {
|
||||
continue
|
||||
}
|
||||
coreIdx, err := strconv.Atoi(strings.TrimPrefix(name, "cpu"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Fields: user(1) nice(2) system(3) idle(4) iowait(5) irq(6) softirq(7) steal(8) …
|
||||
var total, idle uint64
|
||||
for _, f := range fields[1:] {
|
||||
v, _ := strconv.ParseUint(f, 10, 64)
|
||||
total += v
|
||||
}
|
||||
// idle = idle + iowait (indices 4 and 5 in the original line).
|
||||
if len(fields) > 5 {
|
||||
v4, _ := strconv.ParseUint(fields[4], 10, 64)
|
||||
v5, _ := strconv.ParseUint(fields[5], 10, 64)
|
||||
idle = v4 + v5
|
||||
} else {
|
||||
v4, _ := strconv.ParseUint(fields[4], 10, 64)
|
||||
idle = v4
|
||||
}
|
||||
cores = append(cores, cpuCoreTicks{core: coreIdx, total: total, idle: idle})
|
||||
}
|
||||
return cores
|
||||
}
|
||||
|
||||
func computeUsage(prev, cur []cpuCoreTicks) []CoreUsage {
|
||||
prevMap := make(map[int]cpuCoreTicks, len(prev))
|
||||
for _, c := range prev {
|
||||
prevMap[c.core] = c
|
||||
}
|
||||
usage := make([]CoreUsage, 0, len(cur))
|
||||
for _, c := range cur {
|
||||
p, ok := prevMap[c.core]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
dTotal := c.total - p.total
|
||||
dIdle := c.idle - p.idle
|
||||
var pct float64
|
||||
if dTotal > 0 {
|
||||
pct = float64(dTotal-dIdle) / float64(dTotal) * 100
|
||||
// Round to one decimal.
|
||||
pct = math.Round(pct*10) / 10
|
||||
}
|
||||
usage = append(usage, CoreUsage{Core: c.core, UsagePct: pct})
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
// uptimeAndBoot reads /proc/uptime (seconds since boot) and derives boot time.
|
||||
// On any read error it returns zero values rather than failing the request.
|
||||
func uptimeAndBoot() (int64, time.Time) {
|
||||
data, err := os.ReadFile("/proc/uptime")
|
||||
if err != nil {
|
||||
return 0, time.Time{}
|
||||
}
|
||||
fields := strings.Fields(string(data))
|
||||
if len(fields) == 0 {
|
||||
return 0, time.Time{}
|
||||
}
|
||||
secs, err := strconv.ParseFloat(fields[0], 64)
|
||||
if err != nil {
|
||||
return 0, time.Time{}
|
||||
}
|
||||
boot := time.Now().Add(-time.Duration(secs * float64(time.Second))).UTC()
|
||||
return int64(secs), boot
|
||||
}
|
||||
|
||||
func diskInfo() []DiskInfo {
|
||||
entries, err := mounts.Proc()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
disks := []DiskInfo{}
|
||||
seen := map[string]bool{}
|
||||
for _, e := range entries {
|
||||
// Only real block devices; skip pseudo filesystems and snap's squashfs
|
||||
// loop mounts that would otherwise clutter the list.
|
||||
if !strings.HasPrefix(e.Device, "/dev/") || e.FSType == "squashfs" || seen[e.Mountpoint] {
|
||||
continue
|
||||
}
|
||||
var st syscall.Statfs_t
|
||||
if syscall.Statfs(e.Mountpoint, &st) != nil || st.Blocks == 0 {
|
||||
continue
|
||||
}
|
||||
seen[e.Mountpoint] = true
|
||||
bs := uint64(st.Bsize)
|
||||
disks = append(disks, DiskInfo{
|
||||
Mountpoint: e.Mountpoint,
|
||||
Filesystem: e.Device,
|
||||
FSType: e.FSType,
|
||||
TotalBytes: st.Blocks * bs,
|
||||
FreeBytes: st.Bavail * bs,
|
||||
UsedBytes: (st.Blocks - st.Bfree) * bs,
|
||||
})
|
||||
}
|
||||
return disks
|
||||
}
|
||||
|
||||
func netInfo() []NetInterface {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
out := []NetInterface{}
|
||||
for _, ifi := range ifaces {
|
||||
addrs, _ := ifi.Addrs()
|
||||
strs := []string{}
|
||||
for _, a := range addrs {
|
||||
strs = append(strs, a.String())
|
||||
}
|
||||
out = append(out, NetInterface{
|
||||
Name: ifi.Name,
|
||||
MAC: ifi.HardwareAddr.String(),
|
||||
Up: ifi.Flags&net.FlagUp != 0,
|
||||
Addresses: strs,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func tempInfo() []Temperature {
|
||||
if t := readHwmonTemps("/sys/class/hwmon"); len(t) > 0 {
|
||||
return t
|
||||
}
|
||||
// ponytail: stock Ubuntu server has no coretemp/k10temp loaded, so hwmon
|
||||
// is empty; thermal_zone exposes ACPI sensors (coarser, no chip name).
|
||||
return readThermalZones("/sys/class/thermal")
|
||||
}
|
||||
|
||||
// readThermalZones reads /sys/class/thermal/thermal_zone*/temp as a fallback
|
||||
// for hosts without hwmon chip drivers. "type" names the zone (e.g. acpitz,
|
||||
// x86_pkg_temp); used as both chip and label.
|
||||
func readThermalZones(root string) []Temperature {
|
||||
zones, _ := filepath.Glob(filepath.Join(root, "thermal_zone*"))
|
||||
temps := []Temperature{}
|
||||
for _, dir := range zones {
|
||||
milli, err := strconv.Atoi(readTrim(filepath.Join(dir, "temp")))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c := float64(milli) / 1000
|
||||
if c <= 0 || c >= 150 {
|
||||
continue
|
||||
}
|
||||
name := readTrim(filepath.Join(dir, "type"))
|
||||
if name == "" {
|
||||
name = filepath.Base(dir)
|
||||
}
|
||||
temps = append(temps, Temperature{Chip: name, Label: name, Celsius: c})
|
||||
}
|
||||
return temps
|
||||
}
|
||||
|
||||
// readHwmonTemps walks /sys/class/hwmon, which (unlike /sys/class/thermal, that
|
||||
// only exposes generic ACPI zones like "acpitz") names each chip — so callers
|
||||
// can split CPU (k10temp/coretemp) from GPU (amdgpu/nvidia) from disk (nvme).
|
||||
// Each tempN_input may carry a tempN_label; when absent we fall back to the
|
||||
// chip name. Best-effort: unreadable, empty, or implausible sensors are skipped.
|
||||
func readHwmonTemps(root string) []Temperature {
|
||||
chips, _ := filepath.Glob(filepath.Join(root, "hwmon*"))
|
||||
temps := []Temperature{}
|
||||
for _, dir := range chips {
|
||||
chip := readTrim(filepath.Join(dir, "name"))
|
||||
inputs, _ := filepath.Glob(filepath.Join(dir, "temp*_input"))
|
||||
for _, in := range inputs {
|
||||
milli, err := strconv.Atoi(readTrim(in))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c := float64(milli) / 1000
|
||||
// Disabled/placeholder sensors report absurd values (e.g. -0.15 or
|
||||
// 179.8 °C). Drop anything outside a plausible band.
|
||||
if c <= 0 || c >= 150 {
|
||||
continue
|
||||
}
|
||||
label := readTrim(strings.TrimSuffix(in, "_input") + "_label")
|
||||
if label == "" {
|
||||
label = chip
|
||||
}
|
||||
temps = append(temps, Temperature{Chip: chip, Label: label, Celsius: c})
|
||||
}
|
||||
}
|
||||
return temps
|
||||
}
|
||||
|
||||
// readTrim reads a sysfs file and trims it; a missing file yields "".
|
||||
func readTrim(path string) string {
|
||||
b, _ := os.ReadFile(path)
|
||||
return strings.TrimSpace(string(b))
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadHwmonTemps(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
// k10temp: CPU, labelled Tctl.
|
||||
write(t, root, "hwmon0", "name", "k10temp")
|
||||
write(t, root, "hwmon0", "temp1_input", "62125")
|
||||
write(t, root, "hwmon0", "temp1_label", "Tctl")
|
||||
// amdgpu: GPU, no label -> falls back to chip name. Also a sensor reading
|
||||
// an implausible value that must be dropped.
|
||||
write(t, root, "hwmon1", "name", "amdgpu")
|
||||
write(t, root, "hwmon1", "temp1_input", "56000")
|
||||
write(t, root, "hwmon1", "temp2_input", "-150") // disabled placeholder
|
||||
// empty input -> skipped without error.
|
||||
write(t, root, "hwmon2", "name", "cros_ec")
|
||||
write(t, root, "hwmon2", "temp1_input", "")
|
||||
|
||||
got := readHwmonTemps(root)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("want 2 temps, got %d: %+v", len(got), got)
|
||||
}
|
||||
if got[0].Chip != "k10temp" || got[0].Label != "Tctl" || got[0].Celsius != 62.125 {
|
||||
t.Errorf("cpu sensor wrong: %+v", got[0])
|
||||
}
|
||||
if got[1].Chip != "amdgpu" || got[1].Label != "amdgpu" || got[1].Celsius != 56 {
|
||||
t.Errorf("gpu sensor wrong (label should fall back to chip): %+v", got[1])
|
||||
}
|
||||
}
|
||||
|
||||
func write(t *testing.T, root, chip, file, val string) {
|
||||
t.Helper()
|
||||
dir := filepath.Join(root, chip)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte(val), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMeminfo(t *testing.T) {
|
||||
data := []byte(`MemTotal: 16000 kB
|
||||
MemFree: 2000 kB
|
||||
MemAvailable: 8000 kB
|
||||
SwapTotal: 4000 kB
|
||||
SwapFree: 3000 kB
|
||||
`)
|
||||
m := parseMeminfo(data)
|
||||
const kb = 1024
|
||||
if m.TotalBytes != 16000*kb {
|
||||
t.Errorf("TotalBytes = %d", m.TotalBytes)
|
||||
}
|
||||
if m.AvailableBytes != 8000*kb {
|
||||
t.Errorf("AvailableBytes = %d", m.AvailableBytes)
|
||||
}
|
||||
if m.UsedBytes != (16000-8000)*kb {
|
||||
t.Errorf("UsedBytes = %d, want %d", m.UsedBytes, (16000-8000)*kb)
|
||||
}
|
||||
if m.SwapTotalBytes != 4000*kb || m.SwapFreeBytes != 3000*kb {
|
||||
t.Errorf("swap parsed wrong: %+v", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLoadavg(t *testing.T) {
|
||||
l := parseLoadavg("0.42 0.55 0.61 1/234 5678")
|
||||
if l.Load1 != 0.42 || l.Load5 != 0.55 || l.Load15 != 0.61 {
|
||||
t.Errorf("got %+v", l)
|
||||
}
|
||||
bad := parseLoadavg("garbage")
|
||||
if bad.Load1 != 0 || bad.Load5 != 0 || bad.Load15 != 0 {
|
||||
t.Error("short input should yield zero LoadInfo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCpuFreqMHz(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
// Hardware limits live under cpu0 only.
|
||||
write(t, root, "cpu0/cpufreq", "cpuinfo_min_freq", "400000") // 400 MHz
|
||||
write(t, root, "cpu0/cpufreq", "cpuinfo_max_freq", "5137000") // 5137 MHz
|
||||
// Per-core current frequencies (kHz). Core 2 has the peak.
|
||||
write(t, root, "cpu0/cpufreq", "scaling_cur_freq", "2500000")
|
||||
write(t, root, "cpu1/cpufreq", "scaling_cur_freq", "1100000")
|
||||
write(t, root, "cpu2/cpufreq", "scaling_cur_freq", "3200000")
|
||||
write(t, root, "cpu3/cpufreq", "scaling_cur_freq", "2800000")
|
||||
|
||||
min, max, cur := cpuFreqMHz(root)
|
||||
if min != 400 {
|
||||
t.Errorf("min = %d, want 400", min)
|
||||
}
|
||||
if max != 5137 {
|
||||
t.Errorf("max = %d, want 5137", max)
|
||||
}
|
||||
if cur != 3200 {
|
||||
t.Errorf("cur = %d, want 3200 (peak across cores)", cur)
|
||||
}
|
||||
|
||||
// Absent cpufreq → all zeros.
|
||||
emptyRoot := t.TempDir()
|
||||
min, max, cur = cpuFreqMHz(emptyRoot)
|
||||
if min != 0 || max != 0 || cur != 0 {
|
||||
t.Errorf("absent cpufreq: got min=%d max=%d cur=%d, want all 0", min, max, cur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCPUModel(t *testing.T) {
|
||||
x86 := "processor\t: 0\nmodel name\t: AMD Ryzen 7 7840U\ncpu MHz\t: 3000\n"
|
||||
if got := cpuModel(x86); got != "AMD Ryzen 7 7840U" {
|
||||
t.Errorf("x86: got %q", got)
|
||||
}
|
||||
arm := "processor\t: 0\nModel\t: Raspberry Pi 5 Model B\n"
|
||||
if got := cpuModel(arm); got != "Raspberry Pi 5 Model B" {
|
||||
t.Errorf("arm fallback: got %q", got)
|
||||
}
|
||||
if got := cpuModel("processor\t: 0\n"); got != "" {
|
||||
t.Errorf("no model: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadProcStat(t *testing.T) {
|
||||
content := `cpu 100 20 30 400 10 5 3 2 0 0
|
||||
cpu0 50 10 15 200 5 3 1 1 0 0
|
||||
cpu1 50 10 15 200 5 2 2 1 0 0
|
||||
intr 12345
|
||||
`
|
||||
f := filepath.Join(t.TempDir(), "stat")
|
||||
if err := os.WriteFile(f, []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cores := readProcStat(f)
|
||||
if len(cores) != 2 {
|
||||
t.Fatalf("want 2 cores, got %d", len(cores))
|
||||
}
|
||||
// cpu0: user=50 nice=10 sys=15 idle=200 iowait=5 irq=3 softirq=1 steal=1 guest=0 guest_nice=0
|
||||
// total = 50+10+15+200+5+3+1+1+0+0 = 285
|
||||
// idle = 200 + 5 = 205
|
||||
if cores[0].core != 0 {
|
||||
t.Errorf("core index = %d, want 0", cores[0].core)
|
||||
}
|
||||
if cores[0].total != 285 {
|
||||
t.Errorf("core0 total = %d, want 285", cores[0].total)
|
||||
}
|
||||
if cores[0].idle != 205 {
|
||||
t.Errorf("core0 idle = %d, want 205", cores[0].idle)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeUsage(t *testing.T) {
|
||||
prev := []cpuCoreTicks{
|
||||
{core: 0, total: 1000, idle: 800},
|
||||
{core: 1, total: 1000, idle: 900},
|
||||
}
|
||||
cur := []cpuCoreTicks{
|
||||
{core: 0, total: 1100, idle: 850}, // 100 total delta, 50 idle delta → 50% usage
|
||||
{core: 1, total: 1200, idle: 950}, // 200 total delta, 50 idle delta → 75% usage
|
||||
}
|
||||
usage := computeUsage(prev, cur)
|
||||
if len(usage) != 2 {
|
||||
t.Fatalf("want 2 entries, got %d", len(usage))
|
||||
}
|
||||
if usage[0].Core != 0 || usage[0].UsagePct != 50.0 {
|
||||
t.Errorf("core0: got %+v, want 50%%", usage[0])
|
||||
}
|
||||
if usage[1].Core != 1 || usage[1].UsagePct != 75.0 {
|
||||
t.Errorf("core1: got %+v, want 75%%", usage[1])
|
||||
}
|
||||
|
||||
// Empty prev → empty result.
|
||||
if got := computeUsage(nil, cur); len(got) != 0 {
|
||||
t.Errorf("nil prev should yield empty, got %d", len(got))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
type LocaleStatusBody struct {
|
||||
Lang string `json:"lang" example:"it_IT.UTF-8" doc:"System locale (LANG)"`
|
||||
VCKeymap string `json:"vc_keymap" example:"it" doc:"Virtual console keymap"`
|
||||
X11Layout string `json:"x11_layout" example:"it" doc:"X11 keyboard layout"`
|
||||
}
|
||||
|
||||
type GetLocaleOutput struct{ Body LocaleStatusBody }
|
||||
|
||||
type LocalesOutput struct {
|
||||
Body struct {
|
||||
Locales []string `json:"locales" doc:"Available locales"`
|
||||
}
|
||||
}
|
||||
|
||||
type KeymapsOutput struct {
|
||||
Body struct {
|
||||
Keymaps []string `json:"keymaps" doc:"Available virtual console keymaps"`
|
||||
}
|
||||
}
|
||||
|
||||
type SetLocaleInput struct {
|
||||
Body struct {
|
||||
Lang string `json:"lang" example:"it_IT.UTF-8" doc:"Locale to set as LANG"`
|
||||
}
|
||||
}
|
||||
|
||||
type SetKeymapInput struct {
|
||||
Body struct {
|
||||
Keymap string `json:"keymap" example:"it" doc:"Virtual console keymap"`
|
||||
}
|
||||
}
|
||||
|
||||
func localeStatus() (LocaleStatusBody, error) {
|
||||
lines, err := oscmd.RunLines("localectl", "status")
|
||||
if err != nil {
|
||||
return LocaleStatusBody{}, err
|
||||
}
|
||||
var b LocaleStatusBody
|
||||
for _, line := range lines {
|
||||
label, val, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(label) {
|
||||
case "System Locale":
|
||||
for kv := range strings.FieldsSeq(val) {
|
||||
if k, v, ok := strings.Cut(kv, "="); ok && k == "LANG" {
|
||||
b.Lang = v
|
||||
}
|
||||
}
|
||||
case "VC Keymap":
|
||||
b.VCKeymap = strings.TrimSpace(val)
|
||||
case "X11 Layout":
|
||||
b.X11Layout = strings.TrimSpace(val)
|
||||
}
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func registerLocale(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-get-locale",
|
||||
Method: "GET",
|
||||
Path: "/api/system/locale",
|
||||
Summary: "Get locale and keyboard layout",
|
||||
Description: "Returns the system locale (LANG), virtual console keymap, " +
|
||||
"and X11 layout from localectl.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*GetLocaleOutput, error) {
|
||||
b, err := localeStatus()
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("localectl failed", err)
|
||||
}
|
||||
return &GetLocaleOutput{Body: b}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-list-locales",
|
||||
Method: "GET",
|
||||
Path: "/api/system/locales",
|
||||
Summary: "List available locales",
|
||||
Description: "Returns every locale the host can be configured to use, " +
|
||||
"suitable for populating a selector.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*LocalesOutput, error) {
|
||||
locales, err := oscmd.RunLines("localectl", "list-locales")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("localectl failed", err)
|
||||
}
|
||||
out := &LocalesOutput{}
|
||||
out.Body.Locales = locales
|
||||
return out, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-set-locale",
|
||||
Method: "POST",
|
||||
Path: "/api/system/locale",
|
||||
Summary: "Set system locale (LANG)",
|
||||
Description: "Sets LANG via localectl. The value is validated against the " +
|
||||
"host's locale list, so an unknown locale returns 400.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *SetLocaleInput) (*oscmd.StatusOutput, error) {
|
||||
lang := strings.TrimSpace(in.Body.Lang)
|
||||
if lang == "" {
|
||||
return nil, huma.Error400BadRequest("empty locale")
|
||||
}
|
||||
locales, err := oscmd.RunLines("localectl", "list-locales")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("localectl failed", err)
|
||||
}
|
||||
if !slices.Contains(locales, lang) {
|
||||
return nil, huma.Error400BadRequest("unknown locale: " + lang)
|
||||
}
|
||||
if _, err := oscmd.Run("localectl", "set-locale", "LANG="+lang); err != nil {
|
||||
return nil, huma.Error500InternalServerError("set-locale failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-list-keymaps",
|
||||
Method: "GET",
|
||||
Path: "/api/system/keymaps",
|
||||
Summary: "List available console keymaps",
|
||||
Description: "Returns every virtual console keymap known to the host.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*KeymapsOutput, error) {
|
||||
keymaps, err := oscmd.RunLines("localectl", "list-keymaps")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("localectl failed", err)
|
||||
}
|
||||
out := &KeymapsOutput{}
|
||||
out.Body.Keymaps = keymaps
|
||||
return out, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-set-keymap",
|
||||
Method: "POST",
|
||||
Path: "/api/system/keymap",
|
||||
Summary: "Set virtual console keymap",
|
||||
Description: "Sets the virtual console keymap via localectl. The value is " +
|
||||
"validated against the host's keymap list, so an unknown keymap " +
|
||||
"returns 400.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *SetKeymapInput) (*oscmd.StatusOutput, error) {
|
||||
km := strings.TrimSpace(in.Body.Keymap)
|
||||
if km == "" {
|
||||
return nil, huma.Error400BadRequest("empty keymap")
|
||||
}
|
||||
keymaps, err := oscmd.RunLines("localectl", "list-keymaps")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("localectl failed", err)
|
||||
}
|
||||
if !slices.Contains(keymaps, km) {
|
||||
return nil, huma.Error400BadRequest("unknown keymap: " + km)
|
||||
}
|
||||
if _, err := oscmd.Run("localectl", "set-keymap", km); err != nil {
|
||||
return nil, huma.Error500InternalServerError("set-keymap failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const ModuleID = "system"
|
||||
|
||||
type Module struct{}
|
||||
|
||||
func New() *Module { return &Module{} }
|
||||
|
||||
func (m *Module) ID() string { return ModuleID }
|
||||
func (m *Module) Name() string { return "System" }
|
||||
|
||||
func (m *Module) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Read, rbac.Write, rbac.Root}
|
||||
}
|
||||
|
||||
func (m *Module) Register(api huma.API) {
|
||||
registerInfo(api)
|
||||
registerHostname(api)
|
||||
registerTimedate(api)
|
||||
registerLocale(api)
|
||||
registerPower(api)
|
||||
}
|
||||
|
||||
func op(permission string) map[string]any {
|
||||
return map[string]any{
|
||||
"module": ModuleID,
|
||||
"permission": permission,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
// whenRe matches the shutdown(8) TIME forms we accept: "now", "+<minutes>", or
|
||||
// "hh:mm". Validating before exec keeps a value like "-r" or "--help" from
|
||||
// being read as a flag (shutdown does not honor a "--" separator), per the
|
||||
// repo's input-validation rule.
|
||||
var whenRe = regexp.MustCompile(`^(now|\+\d+|([01]?\d|2[0-3]):[0-5]\d)$`)
|
||||
|
||||
type PowerInput struct {
|
||||
Body struct {
|
||||
When string `json:"when,omitempty" example:"+1" doc:"When to act, as a shutdown(8) TIME (e.g. \"+5\" minutes, \"23:00\"). Empty or 'now' = immediate."`
|
||||
}
|
||||
}
|
||||
|
||||
const powerDescription = "Requires the `root` permission. The response is sent once `shutdown` " +
|
||||
"accepts the request; for the immediate form it returns before the machine " +
|
||||
"actually goes down, so a 200 does not guarantee a clean shutdown completed."
|
||||
|
||||
func registerPower(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-reboot",
|
||||
Method: "POST",
|
||||
Path: "/api/system/reboot",
|
||||
Summary: "Reboot the system",
|
||||
Description: powerDescription,
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("root"),
|
||||
Errors: []int{400, 401, 403, 500},
|
||||
}, func(ctx context.Context, in *PowerInput) (*oscmd.StatusOutput, error) {
|
||||
return schedulePower("reboot", in.Body.When)
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-poweroff",
|
||||
Method: "POST",
|
||||
Path: "/api/system/poweroff",
|
||||
Summary: "Power off the system",
|
||||
Description: powerDescription,
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("root"),
|
||||
Errors: []int{400, 401, 403, 500},
|
||||
}, func(ctx context.Context, in *PowerInput) (*oscmd.StatusOutput, error) {
|
||||
return schedulePower("poweroff", in.Body.When)
|
||||
})
|
||||
}
|
||||
|
||||
func schedulePower(action, when string) (*oscmd.StatusOutput, error) {
|
||||
w := strings.TrimSpace(when)
|
||||
if w == "" {
|
||||
w = "now"
|
||||
}
|
||||
if !whenRe.MatchString(w) {
|
||||
return nil, huma.Error400BadRequest("invalid when: " + w)
|
||||
}
|
||||
flag := "-h"
|
||||
if action == "reboot" {
|
||||
flag = "-r"
|
||||
}
|
||||
if _, err := oscmd.Run("shutdown", flag, w); err != nil {
|
||||
return nil, huma.Error500InternalServerError("shutdown failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package system
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestWhenRe(t *testing.T) {
|
||||
valid := []string{"now", "+0", "+5", "+120", "0:00", "9:30", "23:59", "07:05"}
|
||||
for _, w := range valid {
|
||||
if !whenRe.MatchString(w) {
|
||||
t.Errorf("whenRe.MatchString(%q) = false, want true", w)
|
||||
}
|
||||
}
|
||||
invalid := []string{"", "-r", "-h", "--help", "24:00", "9:60", "+5; reboot", "now ", "5", "1:2"}
|
||||
for _, w := range invalid {
|
||||
if whenRe.MatchString(w) {
|
||||
t.Errorf("whenRe.MatchString(%q) = true, want false", w)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestSystemHandlers(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
|
||||
registerHostname(api)
|
||||
registerTimedate(api)
|
||||
registerLocale(api)
|
||||
registerPower(api)
|
||||
registerInfo(api)
|
||||
|
||||
// Mock uname for GET /api/system/info
|
||||
oscmd.SetMock("uname", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"-r"}) {
|
||||
return oscmd.MockCommand{Stdout: "6.9.1-1-test\n", ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"-m"}) {
|
||||
return oscmd.MockCommand{Stdout: "x86_64\n", ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
defer oscmd.ClearMocks()
|
||||
|
||||
// 1. Test GET /api/system/info
|
||||
resp := api.Get("/api/system/info")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get info: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 2. Test GET & POST /api/system/hostname
|
||||
oscmd.SetMock("hostnamectl", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"hostname"}) {
|
||||
return oscmd.MockCommand{Stdout: "server01\n", ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"set-hostname", "server02"}) {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
|
||||
resp = api.Get("/api/system/hostname")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get hostname: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var hostnameRes GetHostnameOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &hostnameRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if hostnameRes.Body.Hostname != "server01" {
|
||||
t.Errorf("got hostname %q, want %q", hostnameRes.Body.Hostname, "server01")
|
||||
}
|
||||
|
||||
resp = api.Post("/api/system/hostname", struct {
|
||||
Hostname string `json:"hostname"`
|
||||
}{
|
||||
Hostname: "server02",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set hostname: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 3. Test GET & POST /api/system/time
|
||||
oscmd.SetMock("timedatectl", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"show"}) {
|
||||
showOut := "Timezone=Europe/Rome\nLocalRTC=no\nNTP=yes\nNTPSynchronized=yes\nCanNTP=yes\n"
|
||||
return oscmd.MockCommand{Stdout: showOut, ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"list-timezones"}) {
|
||||
return oscmd.MockCommand{Stdout: "Europe/Rome\nUTC\n", ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"set-timezone", "Europe/Rome"}) {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"set-ntp", "true"}) {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
if len(args) == 2 && args[0] == "set-time" {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
|
||||
resp = api.Get("/api/system/time")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get time: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var timeRes GetTimeOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &timeRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if timeRes.Body.Timezone != "Europe/Rome" || !timeRes.Body.NTP {
|
||||
t.Errorf("got time settings: %+v", timeRes.Body)
|
||||
}
|
||||
|
||||
resp = api.Get("/api/system/timezones")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list timezones: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
resp = api.Post("/api/system/timezone", struct {
|
||||
Timezone string `json:"timezone"`
|
||||
}{
|
||||
Timezone: "Europe/Rome",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set timezone: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
resp = api.Post("/api/system/ntp", struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set ntp: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
resp = api.Post("/api/system/time", struct {
|
||||
Time string `json:"time"`
|
||||
}{
|
||||
Time: "2026-06-20T12:00:00Z",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set time: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 4. Test GET & POST /api/system/locale
|
||||
oscmd.SetMock("localectl", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"status"}) {
|
||||
statusOut := " System Locale: LANG=it_IT.UTF-8\n VC Keymap: it\n X11 Layout: it\n"
|
||||
return oscmd.MockCommand{Stdout: statusOut, ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"list-locales"}) {
|
||||
return oscmd.MockCommand{Stdout: "it_IT.UTF-8\nen_US.UTF-8\n", ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"set-locale", "LANG=it_IT.UTF-8"}) {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"list-keymaps"}) {
|
||||
return oscmd.MockCommand{Stdout: "it\nus\n", ExitCode: 0}
|
||||
}
|
||||
if reflect.DeepEqual(args, []string{"set-keymap", "it"}) {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
|
||||
resp = api.Get("/api/system/locale")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get locale: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var localeRes GetLocaleOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &localeRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if localeRes.Body.Lang != "it_IT.UTF-8" || localeRes.Body.VCKeymap != "it" {
|
||||
t.Errorf("got locale status: %+v", localeRes.Body)
|
||||
}
|
||||
|
||||
resp = api.Get("/api/system/locales")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list locales: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
resp = api.Post("/api/system/locale", struct {
|
||||
Lang string `json:"lang"`
|
||||
}{
|
||||
Lang: "it_IT.UTF-8",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set locale: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
resp = api.Get("/api/system/keymaps")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list keymaps: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
resp = api.Post("/api/system/keymap", struct {
|
||||
Keymap string `json:"keymap"`
|
||||
}{
|
||||
Keymap: "it",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set keymap: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 5. Test POST /api/system/reboot and /api/system/poweroff
|
||||
oscmd.SetMock("shutdown", func(args []string) oscmd.MockCommand {
|
||||
if reflect.DeepEqual(args, []string{"-r", "now"}) || reflect.DeepEqual(args, []string{"-h", "now"}) {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 1}
|
||||
})
|
||||
|
||||
resp = api.Post("/api/system/reboot", struct {
|
||||
When string `json:"when"`
|
||||
}{
|
||||
When: "now",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("reboot: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
resp = api.Post("/api/system/poweroff", struct {
|
||||
When string `json:"when"`
|
||||
}{
|
||||
When: "now",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("poweroff: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
type TimeStatusBody struct {
|
||||
Timezone string `json:"timezone" example:"Europe/Rome" doc:"IANA timezone name"`
|
||||
LocalRTC bool `json:"local_rtc" doc:"Hardware clock kept in local time instead of UTC"`
|
||||
NTP bool `json:"ntp" doc:"Network time synchronization enabled"`
|
||||
NTPSynchronized bool `json:"ntp_synchronized" doc:"Clock is currently synchronized"`
|
||||
CanNTP bool `json:"can_ntp" doc:"An NTP service is available on this host"`
|
||||
Time string `json:"time" example:"2026-06-19T13:36:31Z" doc:"Current system time (RFC3339, UTC)"`
|
||||
}
|
||||
|
||||
type GetTimeOutput struct{ Body TimeStatusBody }
|
||||
|
||||
type TimezonesOutput struct {
|
||||
Body struct {
|
||||
Timezones []string `json:"timezones" doc:"Available IANA timezone names"`
|
||||
}
|
||||
}
|
||||
|
||||
type SetTimezoneInput struct {
|
||||
Body struct {
|
||||
Timezone string `json:"timezone" example:"Europe/Rome" doc:"IANA timezone name"`
|
||||
}
|
||||
}
|
||||
|
||||
type SetNTPInput struct {
|
||||
Body struct {
|
||||
Enabled bool `json:"enabled" doc:"Enable network time synchronization"`
|
||||
}
|
||||
}
|
||||
|
||||
type SetTimeInput struct {
|
||||
Body struct {
|
||||
Time string `json:"time" example:"2026-06-19T13:36:31Z" doc:"New time (RFC3339). Requires NTP disabled."`
|
||||
}
|
||||
}
|
||||
|
||||
func timedatectlShow() (map[string]string, error) {
|
||||
lines, err := oscmd.RunLines("timedatectl", "show")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return oscmd.ParseKV(lines), nil
|
||||
}
|
||||
|
||||
// readTimeStatus builds the current time/timezone/sync status. Shared by the
|
||||
// GET endpoint and the write endpoints, which return the resulting state so
|
||||
// clients render ground truth instead of guessing after an opaque "ok". Note
|
||||
// NTPSynchronized lags NTP by seconds-to-minutes while the NTP daemon converges.
|
||||
func readTimeStatus() (TimeStatusBody, error) {
|
||||
m, err := timedatectlShow()
|
||||
if err != nil {
|
||||
return TimeStatusBody{}, err
|
||||
}
|
||||
return TimeStatusBody{
|
||||
Timezone: m["Timezone"],
|
||||
LocalRTC: m["LocalRTC"] == "yes",
|
||||
NTP: m["NTP"] == "yes",
|
||||
NTPSynchronized: m["NTPSynchronized"] == "yes",
|
||||
CanNTP: m["CanNTP"] == "yes",
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func registerTimedate(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-get-time",
|
||||
Method: "GET",
|
||||
Path: "/api/system/time",
|
||||
Summary: "Get time, timezone and sync status",
|
||||
Description: "Returns the current UTC time plus timezone and " +
|
||||
"synchronization state from timedatectl.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*GetTimeOutput, error) {
|
||||
body, err := readTimeStatus()
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("timedatectl failed", err)
|
||||
}
|
||||
return &GetTimeOutput{Body: body}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-list-timezones",
|
||||
Method: "GET",
|
||||
Path: "/api/system/timezones",
|
||||
Summary: "List available timezones",
|
||||
Description: "Returns every IANA timezone name known to the host, " +
|
||||
"suitable for populating a selector.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*TimezonesOutput, error) {
|
||||
zones, err := oscmd.RunLines("timedatectl", "list-timezones")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("timedatectl failed", err)
|
||||
}
|
||||
out := &TimezonesOutput{}
|
||||
out.Body.Timezones = zones
|
||||
return out, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-set-timezone",
|
||||
Method: "POST",
|
||||
Path: "/api/system/timezone",
|
||||
Summary: "Set system timezone",
|
||||
Description: "Sets the timezone via timedatectl. The value is validated " +
|
||||
"against the host's timezone list, so an unknown name returns 400.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *SetTimezoneInput) (*oscmd.StatusOutput, error) {
|
||||
tz := strings.TrimSpace(in.Body.Timezone)
|
||||
if tz == "" {
|
||||
return nil, huma.Error400BadRequest("empty timezone")
|
||||
}
|
||||
zones, err := oscmd.RunLines("timedatectl", "list-timezones")
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("timedatectl failed", err)
|
||||
}
|
||||
if !slices.Contains(zones, tz) {
|
||||
return nil, huma.Error400BadRequest("unknown timezone: " + tz)
|
||||
}
|
||||
if _, err := oscmd.Run("timedatectl", "set-timezone", tz); err != nil {
|
||||
return nil, huma.Error500InternalServerError("set-timezone failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-set-ntp",
|
||||
Method: "POST",
|
||||
Path: "/api/system/ntp",
|
||||
Summary: "Enable or disable time synchronization",
|
||||
Description: "Toggles network time synchronization via " +
|
||||
"`timedatectl set-ntp`. Selecting specific NTP servers is not yet " +
|
||||
"supported. Returns the resulting time status: on enable, `ntp` is " +
|
||||
"true immediately, but `ntp_synchronized` stays false until the NTP " +
|
||||
"daemon converges (seconds to minutes).",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("write"),
|
||||
Errors: []int{400, 401, 403, 409, 500},
|
||||
}, func(ctx context.Context, in *SetNTPInput) (*GetTimeOutput, error) {
|
||||
// Enabling is a no-op when no NTP daemon is installed (CanNTP=no):
|
||||
// timedatectl reports success but nothing syncs. Reject it clearly.
|
||||
if in.Body.Enabled {
|
||||
if m, err := timedatectlShow(); err == nil && m["CanNTP"] != "yes" {
|
||||
return nil, huma.Error409Conflict("cannot enable NTP: no NTP service available on this host")
|
||||
}
|
||||
}
|
||||
val := strconv.FormatBool(in.Body.Enabled)
|
||||
if _, err := oscmd.Run("timedatectl", "set-ntp", val); err != nil {
|
||||
return nil, huma.Error500InternalServerError("set-ntp failed", err)
|
||||
}
|
||||
body, err := readTimeStatus()
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("set-ntp succeeded but reading status failed", err)
|
||||
}
|
||||
return &GetTimeOutput{Body: body}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "system-set-time",
|
||||
Method: "POST",
|
||||
Path: "/api/system/time",
|
||||
Summary: "Set system time manually",
|
||||
Description: "Sets the clock to an explicit RFC3339 time. This only works " +
|
||||
"when NTP is disabled; otherwise timedatectl refuses and the call " +
|
||||
"returns 409.",
|
||||
Tags: []string{tagSystem},
|
||||
Metadata: op("write"),
|
||||
Errors: []int{400, 401, 403, 409, 500},
|
||||
}, func(ctx context.Context, in *SetTimeInput) (*oscmd.StatusOutput, error) {
|
||||
t, err := time.Parse(time.RFC3339, strings.TrimSpace(in.Body.Time))
|
||||
if err != nil {
|
||||
return nil, huma.Error400BadRequest("time must be RFC3339", err)
|
||||
}
|
||||
// timedatectl set-time interprets its argument as local wall-clock time.
|
||||
stamp := t.Local().Format("2006-01-02 15:04:05")
|
||||
if _, err := oscmd.Run("timedatectl", "set-time", stamp); err != nil {
|
||||
// timedatectl refuses while NTP is active; make that actionable.
|
||||
return nil, huma.Error409Conflict("set-time failed (disable NTP first?)", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"nadir/internal/auth"
|
||||
"nadir/internal/module"
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/creack/pty"
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
)
|
||||
|
||||
// tagTerminal is the OpenAPI tag for this module (registered in server.go),
|
||||
// keeping tags 1:1 with modules per the project convention.
|
||||
const tagTerminal = "Terminal"
|
||||
|
||||
const (
|
||||
// maxTerminals caps concurrent shells so a buggy frontend or careless admin
|
||||
// can't pile up PTYs (guideline: limit everything). Raise if it's ever a real
|
||||
// limit in practice.
|
||||
maxTerminals = 10
|
||||
// idleTimeout closes a session after this long with no I/O in either
|
||||
// direction, reclaiming abandoned shells.
|
||||
idleTimeout = 15 * time.Minute
|
||||
)
|
||||
|
||||
// terminalSem is the concurrency limiter; a slot is held for the life of a session.
|
||||
var terminalSem = make(chan struct{}, maxTerminals)
|
||||
|
||||
type terminalModule struct {
|
||||
sessions *auth.SessionStore
|
||||
}
|
||||
|
||||
// New creates a new Terminal module that allows interactive shell access.
|
||||
func New(sessions *auth.SessionStore) module.Module {
|
||||
return &terminalModule{sessions: sessions}
|
||||
}
|
||||
|
||||
func (m *terminalModule) ID() string { return "terminal" }
|
||||
func (m *terminalModule) Name() string { return "Terminal" }
|
||||
|
||||
func (m *terminalModule) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Root}
|
||||
}
|
||||
|
||||
type TerminalInput struct {
|
||||
Ctx huma.Context `json:"-"`
|
||||
}
|
||||
|
||||
// Resolve extracts the huma.Context into the input struct.
|
||||
func (i *TerminalInput) Resolve(ctx huma.Context) []error {
|
||||
i.Ctx = ctx
|
||||
return nil
|
||||
}
|
||||
|
||||
type resizeMessage struct {
|
||||
Cols uint16 `json:"cols"`
|
||||
Rows uint16 `json:"rows"`
|
||||
}
|
||||
|
||||
func (m *terminalModule) Register(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "terminal-connect",
|
||||
Method: "GET",
|
||||
Path: "/api/terminal",
|
||||
Summary: "Connect to an interactive terminal",
|
||||
Description: "Upgrades the connection to a WebSocket and spawns a PTY shell as the logged-in user. Send JSON `{cols, rows}` text messages to resize, and raw binary/text messages for stdin. This is a raw WebSocket endpoint — it cannot be exercised from the API docs \"Try it\" panel; use a WebSocket client.",
|
||||
Tags: []string{tagTerminal},
|
||||
Metadata: map[string]any{"module": m.ID(), "permission": string(rbac.Root)},
|
||||
Errors: []int{401, 403, 426, 500},
|
||||
}, func(ctx context.Context, in *TerminalInput) (*struct{}, error) {
|
||||
// The RBAC middleware already authenticated this request and enforced the
|
||||
// "root" permission before we got here. We re-read the session only to get
|
||||
// the username for `su`; the 401 below is the fallback for when the module
|
||||
// is mounted without the middleware (e.g. in unit tests).
|
||||
cookie, err := huma.ReadCookie(in.Ctx, "nadir_session_id")
|
||||
if err != nil || cookie == nil {
|
||||
return nil, huma.Error401Unauthorized("unauthorized")
|
||||
}
|
||||
|
||||
sess, ok := m.sessions.GetByToken(cookie.Value)
|
||||
if !ok {
|
||||
return nil, huma.Error401Unauthorized("unauthorized")
|
||||
}
|
||||
|
||||
req, res := humago.Unwrap(in.Ctx)
|
||||
if req == nil || res == nil {
|
||||
return nil, huma.Error500InternalServerError("missing http context")
|
||||
}
|
||||
|
||||
// Reject plain GETs (e.g. the docs "Try it" button) with a clear 426 rather
|
||||
// than letting websocket.Accept emit a raw protocol-violation error.
|
||||
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
||||
return nil, huma.NewError(http.StatusUpgradeRequired,
|
||||
"this endpoint requires a WebSocket connection; connect with a WebSocket client")
|
||||
}
|
||||
|
||||
// InsecureSkipVerify is deliberately NOT set: coder/websocket then enforces
|
||||
// that the Origin host matches the request Host, rejecting cross-site upgrade
|
||||
// attempts. Defense-in-depth on top of the SameSite=Strict session cookie —
|
||||
// important because this endpoint hands out an interactive shell.
|
||||
conn, err := websocket.Accept(res, req, nil)
|
||||
if err != nil {
|
||||
// websocket.Accept already wrote the error response (e.g. 403 on an
|
||||
// Origin mismatch). Just stop; writing again would corrupt the response.
|
||||
return nil, nil
|
||||
}
|
||||
defer conn.CloseNow()
|
||||
|
||||
// Bound concurrent shells: take a slot or reject (don't pile up PTYs).
|
||||
select {
|
||||
case terminalSem <- struct{}{}:
|
||||
defer func() { <-terminalSem }()
|
||||
default:
|
||||
conn.Close(websocket.StatusTryAgainLater, "too many terminal sessions")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Launch the user's login shell via su.
|
||||
// "su - <username>" ensures we get their actual environment and shell.
|
||||
cmd := exec.CommandContext(req.Context(), "su", "-", sess.Username)
|
||||
|
||||
// Start the command with a PTY.
|
||||
ptmx, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
conn.Close(websocket.StatusInternalError, "failed to start pty")
|
||||
return nil, nil
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
// lastActive is bumped by both pumps; the watchdog uses it to close idle
|
||||
// sessions. Output activity (e.g. `top`) counts, so it isn't killed.
|
||||
var lastActive atomic.Int64
|
||||
lastActive.Store(time.Now().UnixNano())
|
||||
go func() {
|
||||
tick := time.NewTicker(idleTimeout / 4)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
return
|
||||
case <-tick.C:
|
||||
if time.Since(time.Unix(0, lastActive.Load())) > idleTimeout {
|
||||
conn.Close(websocket.StatusGoingAway, "idle timeout")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Pump stdout/stderr from PTY to WebSocket.
|
||||
go func() {
|
||||
buf := make([]byte, 8192)
|
||||
for {
|
||||
n, err := ptmx.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
lastActive.Store(time.Now().UnixNano())
|
||||
// We write PTY output as binary messages. The frontend (e.g., xterm.js)
|
||||
// can handle UTF-8 binary or text transparently.
|
||||
err = conn.Write(req.Context(), websocket.MessageBinary, buf[:n])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
|
||||
// Pump stdin and resize commands from WebSocket to PTY.
|
||||
for {
|
||||
typ, b, err := conn.Read(req.Context())
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
lastActive.Store(time.Now().UnixNano())
|
||||
|
||||
if typ == websocket.MessageText {
|
||||
var resize resizeMessage
|
||||
if err := json.Unmarshal(b, &resize); err == nil && resize.Cols > 0 && resize.Rows > 0 {
|
||||
// Handle resize
|
||||
_ = pty.Setsize(ptmx, &pty.Winsize{
|
||||
Cols: resize.Cols,
|
||||
Rows: resize.Rows,
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If not a valid resize message, or if it's MessageBinary, pass to PTY stdin.
|
||||
_, _ = ptmx.Write(b)
|
||||
}
|
||||
|
||||
// The read loop has ended (client gone or PTY closed). Tear down the shell
|
||||
// and reap it: closing the PTY sends EOF, Kill covers shells that ignore it.
|
||||
_ = ptmx.Close()
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmd.Wait()
|
||||
return nil, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
package terminal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nadir/internal/auth"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
)
|
||||
|
||||
func TestTerminalConnectUnauthorized(t *testing.T) {
|
||||
sessions, err := auth.NewSessionStore("file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
t.Fatalf("session db: %v", err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))
|
||||
mod := New(sessions)
|
||||
mod.Register(api)
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := strings.Replace(srv.URL, "http://", "ws://", 1) + "/api/terminal"
|
||||
|
||||
// Try without cookie
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, resp, err := websocket.Dial(ctx, wsURL, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got success")
|
||||
}
|
||||
if resp != nil && resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminalConnectPlainGET(t *testing.T) {
|
||||
// An authenticated but non-WebSocket GET (e.g. the docs "Try it" button) must
|
||||
// get a clean 426 Upgrade Required, not a raw websocket protocol error.
|
||||
sessions, err := auth.NewSessionStore("file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
t.Fatalf("session db: %v", err)
|
||||
}
|
||||
token, err := sessions.Create("root")
|
||||
if err != nil {
|
||||
t.Fatalf("create session: %v", err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))
|
||||
New(sessions).Register(api)
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, srv.URL+"/api/terminal", nil)
|
||||
req.Header.Set("Cookie", "nadir_session_id="+token)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GET failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusUpgradeRequired {
|
||||
t.Errorf("expected 426, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminalConnectAuthorized(t *testing.T) {
|
||||
// Root or specific system configs might cause PTY or su to fail if run in constrained environments.
|
||||
// But we expect the websocket upgrade to succeed and then maybe close with an error if PTY fails.
|
||||
sessions, err := auth.NewSessionStore("file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
t.Fatalf("session db: %v", err)
|
||||
}
|
||||
|
||||
token, err := sessions.Create("root")
|
||||
if err != nil {
|
||||
t.Fatalf("create session: %v", err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humago.New(mux, huma.DefaultConfig("Test", "1.0.0"))
|
||||
mod := New(sessions)
|
||||
mod.Register(api)
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
wsURL := strings.Replace(srv.URL, "http://", "ws://", 1) + "/api/terminal"
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
opts := &websocket.DialOptions{
|
||||
HTTPHeader: http.Header{},
|
||||
}
|
||||
opts.HTTPHeader.Set("Cookie", "nadir_session_id="+token)
|
||||
|
||||
conn, resp, err := websocket.Dial(ctx, wsURL, opts)
|
||||
if err != nil {
|
||||
// Depending on the test environment, "su" or "pty" might fail, but
|
||||
// the websocket upgrade itself should succeed before it drops.
|
||||
// If it fails to upgrade, that's a real error.
|
||||
t.Fatalf("dial failed: %v (status %d)", err, resp.StatusCode)
|
||||
}
|
||||
defer conn.CloseNow()
|
||||
|
||||
// The connection was upgraded successfully.
|
||||
// If PTY allocation failed or su failed, the server might close the connection immediately.
|
||||
// We just verify the upgrade succeeded.
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
t.Errorf("expected 101, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const ModuleID = "users"
|
||||
|
||||
type Module struct{}
|
||||
|
||||
func New() *Module { return &Module{} }
|
||||
|
||||
func (m *Module) ID() string { return ModuleID }
|
||||
func (m *Module) Name() string { return "Users" }
|
||||
|
||||
// Permissions: read to list/inspect accounts; write to create and change
|
||||
// passwords; root to delete (irreversible).
|
||||
func (m *Module) Permissions() []rbac.Permission {
|
||||
return []rbac.Permission{rbac.Read, rbac.Write, rbac.Root}
|
||||
}
|
||||
|
||||
func (m *Module) Register(api huma.API) {
|
||||
registerUsers(api)
|
||||
}
|
||||
|
||||
func op(permission string) map[string]any {
|
||||
return map[string]any{"module": ModuleID, "permission": permission}
|
||||
}
|
||||
@@ -0,0 +1,373 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
const tagUsers = "Users"
|
||||
|
||||
var passwdPath = "/etc/passwd"
|
||||
|
||||
// systemUIDMax is the conventional upper bound for system (non-login) accounts;
|
||||
// regular users start at 1000 on Debian/Fedora. Used only to flag accounts for
|
||||
// the client, not to filter them out.
|
||||
const systemUIDMax = 1000
|
||||
|
||||
var (
|
||||
readErrors = []int{401, 403, 500}
|
||||
writeErrors = []int{400, 401, 403, 404, 409, 500}
|
||||
)
|
||||
|
||||
// userNameRe matches valid Linux usernames (the useradd default NAME_REGEX).
|
||||
// Starting with a letter/underscore also rejects leading-dash flag injection.
|
||||
var userNameRe = regexp.MustCompile(`^[a-z_][a-z0-9_-]{0,31}\$?$`)
|
||||
|
||||
// User mirrors one /etc/passwd entry.
|
||||
type User struct {
|
||||
Username string `json:"username" example:"alice" doc:"Login name"`
|
||||
UID int `json:"uid" example:"1000" doc:"User ID"`
|
||||
GID int `json:"gid" example:"1000" doc:"Primary group ID"`
|
||||
Comment string `json:"comment" example:"Alice Smith" doc:"GECOS comment (often the full name)"`
|
||||
Home string `json:"home" example:"/home/alice" doc:"Home directory"`
|
||||
Shell string `json:"shell" example:"/bin/bash" doc:"Login shell"`
|
||||
System bool `json:"system" doc:"True for system accounts (uid < 1000)"`
|
||||
}
|
||||
|
||||
type ListUsersOutput struct {
|
||||
Body struct {
|
||||
Users []User `json:"users" doc:"All accounts from /etc/passwd"`
|
||||
}
|
||||
}
|
||||
|
||||
type GetUserOutput struct{ Body User }
|
||||
|
||||
type UserPath struct {
|
||||
Username string `path:"username" example:"alice" doc:"Login name"`
|
||||
}
|
||||
|
||||
type CreateUserInput struct {
|
||||
Body struct {
|
||||
Username string `json:"username" example:"alice" doc:"Login name"`
|
||||
Comment string `json:"comment,omitempty" example:"Alice Smith" doc:"GECOS comment"`
|
||||
Shell string `json:"shell,omitempty" example:"/bin/bash" doc:"Login shell"`
|
||||
Home string `json:"home,omitempty" example:"/home/alice" doc:"Home directory (defaults to /home/<username>)"`
|
||||
CreateHome bool `json:"create_home,omitempty" doc:"Create the home directory (useradd -m)"`
|
||||
System bool `json:"system,omitempty" doc:"Create a system account (useradd --system)"`
|
||||
}
|
||||
}
|
||||
|
||||
type DeleteUserInput struct {
|
||||
Username string `path:"username" example:"alice" doc:"Login name"`
|
||||
RemoveHome bool `query:"remove_home" doc:"Also remove the home directory and mail spool (userdel -r)"`
|
||||
}
|
||||
|
||||
type SetPasswordInput struct {
|
||||
Username string `path:"username" example:"alice" doc:"Login name"`
|
||||
Body struct {
|
||||
Password string `json:"password" doc:"New password (sent to chpasswd over stdin, never argv)"`
|
||||
}
|
||||
}
|
||||
|
||||
type SetGroupsInput struct {
|
||||
Username string `path:"username" example:"alice" doc:"Login name"`
|
||||
Body struct {
|
||||
Groups []string `json:"groups" doc:"Supplementary groups; replaces the user's full supplementary set"`
|
||||
}
|
||||
}
|
||||
|
||||
type UserGroupsOutput struct {
|
||||
Body struct {
|
||||
Username string `json:"username" example:"alice"`
|
||||
Groups []string `json:"groups" doc:"All groups the user belongs to (primary + supplementary)"`
|
||||
}
|
||||
}
|
||||
|
||||
func registerUsers(api huma.API) {
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "users-list",
|
||||
Method: "GET",
|
||||
Path: "/api/users",
|
||||
Summary: "List user accounts",
|
||||
Description: "Returns every account in /etc/passwd, including system " +
|
||||
"accounts (flagged via `system`).",
|
||||
Tags: []string{tagUsers},
|
||||
Metadata: op("read"),
|
||||
Errors: readErrors,
|
||||
}, func(ctx context.Context, _ *struct{}) (*ListUsersOutput, error) {
|
||||
list, err := listUsers()
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err)
|
||||
}
|
||||
out := &ListUsersOutput{}
|
||||
out.Body.Users = list
|
||||
return out, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "users-get",
|
||||
Method: "GET",
|
||||
Path: "/api/users/{username}",
|
||||
Summary: "Get a single user",
|
||||
Description: "Returns one account by login name. 404 if it does not exist.",
|
||||
Tags: []string{tagUsers},
|
||||
Metadata: op("read"),
|
||||
Errors: []int{400, 401, 403, 404, 500},
|
||||
}, func(ctx context.Context, in *UserPath) (*GetUserOutput, error) {
|
||||
if err := validateUsername(in.Username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u, ok, err := lookupUser(in.Username)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, huma.Error404NotFound("user not found: " + in.Username)
|
||||
}
|
||||
return &GetUserOutput{Body: u}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "users-create",
|
||||
Method: "POST",
|
||||
Path: "/api/users",
|
||||
Summary: "Create a user account",
|
||||
Description: "Creates an account via useradd. The new account has a locked " +
|
||||
"password until one is set via the password endpoint. 409 if the user " +
|
||||
"already exists.",
|
||||
Tags: []string{tagUsers},
|
||||
Metadata: op("write"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *CreateUserInput) (*GetUserOutput, error) {
|
||||
if err := validateUsername(in.Body.Username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// -c/-s/-d are option *arguments*, so the "--" separator doesn't shield
|
||||
// them. Validate at the boundary: a ':' or newline in the GECOS field
|
||||
// would corrupt /etc/passwd; shell/home must be absolute paths.
|
||||
if strings.ContainsAny(in.Body.Comment, ":\n") {
|
||||
return nil, huma.Error400BadRequest("comment may not contain ':' or newlines")
|
||||
}
|
||||
if in.Body.Shell != "" && !filepath.IsAbs(in.Body.Shell) {
|
||||
return nil, huma.Error400BadRequest("shell must be an absolute path")
|
||||
}
|
||||
if in.Body.Home != "" && !filepath.IsAbs(in.Body.Home) {
|
||||
return nil, huma.Error400BadRequest("home must be an absolute path")
|
||||
}
|
||||
if _, ok, err := lookupUser(in.Body.Username); err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err)
|
||||
} else if ok {
|
||||
return nil, huma.Error409Conflict("user already exists: " + in.Body.Username)
|
||||
}
|
||||
|
||||
args := []string{}
|
||||
if in.Body.System {
|
||||
args = append(args, "--system")
|
||||
}
|
||||
if in.Body.CreateHome {
|
||||
args = append(args, "-m")
|
||||
}
|
||||
if in.Body.Comment != "" {
|
||||
args = append(args, "-c", in.Body.Comment)
|
||||
}
|
||||
if in.Body.Shell != "" {
|
||||
args = append(args, "-s", in.Body.Shell)
|
||||
}
|
||||
if in.Body.Home != "" {
|
||||
args = append(args, "-d", in.Body.Home)
|
||||
}
|
||||
args = append(args, "--", in.Body.Username)
|
||||
if _, err := oscmd.Run("useradd", args...); err != nil {
|
||||
return nil, huma.Error500InternalServerError("useradd failed", err)
|
||||
}
|
||||
|
||||
u, ok, err := lookupUser(in.Body.Username)
|
||||
if err != nil || !ok {
|
||||
return nil, huma.Error500InternalServerError("user created but could not be read back", err)
|
||||
}
|
||||
return &GetUserOutput{Body: u}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "users-delete",
|
||||
Method: "DELETE",
|
||||
Path: "/api/users/{username}",
|
||||
Summary: "Delete a user account",
|
||||
Description: "Removes an account via userdel. Pass ?remove_home=true to " +
|
||||
"also delete the home directory. 404 if the user does not exist.",
|
||||
Tags: []string{tagUsers},
|
||||
// Deleting an account is irreversible - gated behind root, not write.
|
||||
Metadata: op("root"),
|
||||
Errors: []int{400, 401, 403, 404, 500},
|
||||
}, func(ctx context.Context, in *DeleteUserInput) (*oscmd.StatusOutput, error) {
|
||||
if err := validateUsername(in.Username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok, err := lookupUser(in.Username); err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err)
|
||||
} else if !ok {
|
||||
return nil, huma.Error404NotFound("user not found: " + in.Username)
|
||||
}
|
||||
|
||||
args := []string{}
|
||||
if in.RemoveHome {
|
||||
args = append(args, "-r")
|
||||
}
|
||||
args = append(args, "--", in.Username)
|
||||
if _, err := oscmd.Run("userdel", args...); err != nil {
|
||||
return nil, huma.Error500InternalServerError("userdel failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "users-set-password",
|
||||
Method: "POST",
|
||||
Path: "/api/users/{username}/password",
|
||||
Summary: "Set a user's password",
|
||||
Description: "Sets the password via chpasswd (fed over stdin, so the secret " +
|
||||
"never appears in the process list). 404 if the user does not exist. " +
|
||||
"Requires the `root` permission: resetting a privileged account's " +
|
||||
"password (e.g. root) is a full-system action, not a routine write.",
|
||||
Tags: []string{tagUsers},
|
||||
Metadata: op("root"),
|
||||
Errors: []int{400, 401, 403, 404, 500},
|
||||
}, func(ctx context.Context, in *SetPasswordInput) (*oscmd.StatusOutput, error) {
|
||||
if err := validateUsername(in.Username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if in.Body.Password == "" {
|
||||
return nil, huma.Error400BadRequest("empty password")
|
||||
}
|
||||
// chpasswd reads one "name:password" line per stdin line, so a newline in
|
||||
// the password would inject a second line and set another account's password.
|
||||
if strings.ContainsAny(in.Body.Password, "\n\r") {
|
||||
return nil, huma.Error400BadRequest("password may not contain newlines")
|
||||
}
|
||||
if _, ok, err := lookupUser(in.Username); err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err)
|
||||
} else if !ok {
|
||||
return nil, huma.Error404NotFound("user not found: " + in.Username)
|
||||
}
|
||||
// chpasswd reads "name:password" lines from stdin.
|
||||
if _, err := oscmd.RunStdin(in.Username+":"+in.Body.Password+"\n", "chpasswd"); err != nil {
|
||||
return nil, huma.Error500InternalServerError("chpasswd failed", err)
|
||||
}
|
||||
return oscmd.OK(), nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "users-set-groups",
|
||||
Method: "PUT",
|
||||
Path: "/api/users/{username}/groups",
|
||||
Summary: "Set a user's supplementary groups",
|
||||
Description: "Replaces the user's full supplementary group set via " +
|
||||
"`usermod -G` (an empty list removes them from all supplementary " +
|
||||
"groups). Returns the resulting group membership. 404 if the user " +
|
||||
"does not exist; 400 if any named group is missing. Requires the " +
|
||||
"`root` permission: adding an account to wheel/sudo/docker is a " +
|
||||
"privilege grant, not a routine write.",
|
||||
Tags: []string{tagUsers},
|
||||
Metadata: op("root"),
|
||||
Errors: writeErrors,
|
||||
}, func(ctx context.Context, in *SetGroupsInput) (*UserGroupsOutput, error) {
|
||||
if err := validateUsername(in.Username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range in.Body.Groups {
|
||||
// Group names follow the same rule as usernames; this also blocks
|
||||
// flag injection and a stray comma turning one group into two.
|
||||
if !userNameRe.MatchString(g) {
|
||||
return nil, huma.Error400BadRequest("invalid group name: " + g)
|
||||
}
|
||||
}
|
||||
if _, ok, err := lookupUser(in.Username); err != nil {
|
||||
return nil, huma.Error500InternalServerError("read "+passwdPath+" failed", err)
|
||||
} else if !ok {
|
||||
return nil, huma.Error404NotFound("user not found: " + in.Username)
|
||||
}
|
||||
if _, err := oscmd.Run("usermod", "-G", strings.Join(in.Body.Groups, ","), "--", in.Username); err != nil {
|
||||
return nil, huma.Error400BadRequest("usermod failed (does every named group exist?)", err)
|
||||
}
|
||||
|
||||
// id -nG lists all groups (primary + supplementary) the user now has.
|
||||
out, err := oscmd.Run("id", "-nG", in.Username)
|
||||
if err != nil {
|
||||
return nil, huma.Error500InternalServerError("groups set but read-back failed", err)
|
||||
}
|
||||
res := &UserGroupsOutput{}
|
||||
res.Body.Username = in.Username
|
||||
res.Body.Groups = strings.Fields(out)
|
||||
return res, nil
|
||||
})
|
||||
}
|
||||
|
||||
// validateUsername rejects empty, flag-like, or malformed names before exec.
|
||||
func validateUsername(name string) error {
|
||||
if !userNameRe.MatchString(name) {
|
||||
return huma.Error400BadRequest("invalid username: " + name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listUsers() ([]User, error) {
|
||||
data, err := os.ReadFile(passwdPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parsePasswd(data), nil
|
||||
}
|
||||
|
||||
func lookupUser(name string) (User, bool, error) {
|
||||
list, err := listUsers()
|
||||
if err != nil {
|
||||
return User{}, false, err
|
||||
}
|
||||
for _, u := range list {
|
||||
if u.Username == name {
|
||||
return u, true, nil
|
||||
}
|
||||
}
|
||||
return User{}, false, nil
|
||||
}
|
||||
|
||||
// parsePasswd parses /etc/passwd content. Lines that are blank, commented, or
|
||||
// malformed (fewer than 7 fields, non-numeric ids) are skipped.
|
||||
func parsePasswd(data []byte) []User {
|
||||
var users []User
|
||||
for line := range strings.SplitSeq(string(data), "\n") {
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
f := strings.Split(line, ":")
|
||||
if len(f) < 7 {
|
||||
continue
|
||||
}
|
||||
uid, err := strconv.Atoi(f[2])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
gid, err := strconv.Atoi(f[3])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
users = append(users, User{
|
||||
Username: f[0],
|
||||
UID: uid,
|
||||
GID: gid,
|
||||
Comment: f[4],
|
||||
Home: f[5],
|
||||
Shell: f[6],
|
||||
System: uid < systemUIDMax,
|
||||
})
|
||||
}
|
||||
return users
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestUsersHandlers(t *testing.T) {
|
||||
tempPasswd := filepath.Join(t.TempDir(), "passwd")
|
||||
initialContent := "root:x:0:0:root:/root:/bin/bash\nalice:x:1000:1000:Alice Smith:/home/alice:/bin/bash\n"
|
||||
if err := os.WriteFile(tempPasswd, []byte(initialContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
oldPasswd := passwdPath
|
||||
passwdPath = tempPasswd
|
||||
defer func() { passwdPath = oldPasswd }()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
registerUsers(api)
|
||||
|
||||
// 1. Test GET /api/users
|
||||
resp := api.Get("/api/users")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("list users: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var listRes ListUsersOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &listRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(listRes.Body.Users) != 2 {
|
||||
t.Errorf("got %d users, want 2", len(listRes.Body.Users))
|
||||
}
|
||||
|
||||
// 2. Test GET /api/users/{username}
|
||||
resp = api.Get("/api/users/alice")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("get user: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
var getRes GetUserOutput
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &getRes.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if getRes.Body.Username != "alice" || getRes.Body.UID != 1000 {
|
||||
t.Errorf("get user: got %+v", getRes.Body)
|
||||
}
|
||||
|
||||
resp = api.Get("/api/users/bob")
|
||||
if resp.Code != http.StatusNotFound {
|
||||
t.Errorf("get non-existent user: got %d, want %d", resp.Code, http.StatusNotFound)
|
||||
}
|
||||
|
||||
// 3. Test POST /api/users
|
||||
oscmd.SetMock("useradd", func(args []string) oscmd.MockCommand {
|
||||
wantArgs := []string{"-m", "-c", "Bob Jones", "-s", "/bin/sh", "--", "bob"}
|
||||
if !reflect.DeepEqual(args, wantArgs) {
|
||||
t.Errorf("useradd args: got %v, want %v", args, wantArgs)
|
||||
}
|
||||
bobContent := initialContent + "bob:x:1001:1001:Bob Jones:/home/bob:/bin/sh\n"
|
||||
os.WriteFile(tempPasswd, []byte(bobContent), 0644)
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
})
|
||||
defer oscmd.ClearMocks()
|
||||
|
||||
resp = api.Post("/api/users", struct {
|
||||
Username string `json:"username"`
|
||||
Comment string `json:"comment"`
|
||||
Shell string `json:"shell"`
|
||||
CreateHome bool `json:"create_home"`
|
||||
}{
|
||||
Username: "bob",
|
||||
Comment: "Bob Jones",
|
||||
Shell: "/bin/sh",
|
||||
CreateHome: true,
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("create user: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 4. Test DELETE /api/users/{username}
|
||||
oscmd.SetMock("userdel", func(args []string) oscmd.MockCommand {
|
||||
wantArgs := []string{"-r", "--", "bob"}
|
||||
if !reflect.DeepEqual(args, wantArgs) {
|
||||
t.Errorf("userdel args: got %v, want %v", args, wantArgs)
|
||||
}
|
||||
os.WriteFile(tempPasswd, []byte(initialContent), 0644)
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
})
|
||||
|
||||
resp = api.Delete("/api/users/bob?remove_home=true")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("delete user: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 5. Test POST /api/users/{username}/password
|
||||
oscmd.SetMock("chpasswd", func(args []string) oscmd.MockCommand {
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
})
|
||||
|
||||
resp = api.Post("/api/users/alice/password", struct {
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Password: "newsecretpwd",
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set password: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 6. Test PUT /api/users/{username}/groups
|
||||
oscmd.SetMock("usermod", func(args []string) oscmd.MockCommand {
|
||||
wantArgs := []string{"-G", "wheel,dev", "--", "alice"}
|
||||
if !reflect.DeepEqual(args, wantArgs) {
|
||||
t.Errorf("usermod args: got %v, want %v", args, wantArgs)
|
||||
}
|
||||
return oscmd.MockCommand{ExitCode: 0}
|
||||
})
|
||||
|
||||
oscmd.SetMock("id", func(args []string) oscmd.MockCommand {
|
||||
wantArgs := []string{"-nG", "alice"}
|
||||
if !reflect.DeepEqual(args, wantArgs) {
|
||||
t.Errorf("id args: got %v, want %v", args, wantArgs)
|
||||
}
|
||||
return oscmd.MockCommand{Stdout: "alice wheel dev\n", ExitCode: 0}
|
||||
})
|
||||
|
||||
resp = api.Put("/api/users/alice/groups", struct {
|
||||
Groups []string `json:"groups"`
|
||||
}{
|
||||
Groups: []string{"wheel", "dev"},
|
||||
})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("set groups: got %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package users
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParsePasswd(t *testing.T) {
|
||||
data := []byte(`root:x:0:0:root:/root:/bin/bash
|
||||
# a comment
|
||||
|
||||
daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin
|
||||
alice:x:1000:1000:Alice Smith:/home/alice:/bin/bash
|
||||
broken:x:notanumber:5:::
|
||||
short:x:2:2
|
||||
`)
|
||||
got := parsePasswd(data)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("expected 3 valid users, got %d: %+v", len(got), got)
|
||||
}
|
||||
|
||||
alice := got[2]
|
||||
if alice.Username != "alice" || alice.UID != 1000 || alice.GID != 1000 ||
|
||||
alice.Comment != "Alice Smith" || alice.Home != "/home/alice" ||
|
||||
alice.Shell != "/bin/bash" || alice.System {
|
||||
t.Errorf("alice parsed wrong: %+v", alice)
|
||||
}
|
||||
if !got[0].System || !got[1].System {
|
||||
t.Error("root/daemon should be flagged as system accounts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUsername(t *testing.T) {
|
||||
valid := []string{"alice", "_svc", "user-1", "a", "machine$", "abc_def"}
|
||||
for _, n := range valid {
|
||||
if err := validateUsername(n); err != nil {
|
||||
t.Errorf("validateUsername(%q) = %v, want nil", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
invalid := []string{
|
||||
"", // empty
|
||||
"-rf", // leading dash (flag injection)
|
||||
"Alice", // uppercase
|
||||
"1user", // leading digit
|
||||
"a b", // space
|
||||
"foo;rm", // shell metachar
|
||||
"root:x", // colon (passwd separator)
|
||||
"waytoolongusernamethatexceedsthirtytwochars", // >32
|
||||
}
|
||||
for _, n := range invalid {
|
||||
if err := validateUsername(n); err == nil {
|
||||
t.Errorf("validateUsername(%q) = nil, want error", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
// Package mounts parses the kernel mount table (/proc/mounts). Both the system
|
||||
// dashboard (disk usage) and the storage module (mount management) need it, so
|
||||
// it lives here rather than being duplicated. It also exposes the octal
|
||||
// unescaping that /proc/mounts and /etc/fstab use for spaces/tabs in paths.
|
||||
package mounts
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// procMounts is a var so tests can point it at a fixture.
|
||||
var procMounts = "/proc/mounts"
|
||||
|
||||
// Mount is one mount-table line: the backing device, where it's mounted, the
|
||||
// filesystem type, and the comma-separated mount options.
|
||||
type Mount struct {
|
||||
Device string `json:"device" example:"/dev/sda1"`
|
||||
Mountpoint string `json:"mountpoint" example:"/mnt/data"`
|
||||
FSType string `json:"fstype" example:"ext4"`
|
||||
Options string `json:"options" example:"rw,relatime"`
|
||||
}
|
||||
|
||||
// Proc reads and parses /proc/mounts.
|
||||
func Proc() ([]Mount, error) {
|
||||
data, err := os.ReadFile(procMounts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseProc(string(data)), nil
|
||||
}
|
||||
|
||||
func parseProc(data string) []Mount {
|
||||
entries := []Mount{}
|
||||
for line := range strings.SplitSeq(data, "\n") {
|
||||
f := strings.Fields(line)
|
||||
if len(f) < 4 { // device mountpoint fstype options [dump pass]
|
||||
continue
|
||||
}
|
||||
entries = append(entries, Mount{
|
||||
Device: Unescape(f[0]),
|
||||
Mountpoint: Unescape(f[1]),
|
||||
FSType: f[2],
|
||||
Options: f[3],
|
||||
})
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
// Unescape decodes the octal escapes (\040 space, \011 tab, \012 newline,
|
||||
// \134 backslash) that mount tables and fstab use for whitespace in fields.
|
||||
func Unescape(s string) string {
|
||||
if !strings.ContainsRune(s, '\\') {
|
||||
return s
|
||||
}
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == '\\' && i+3 < len(s) && isOctal(s[i+1]) && isOctal(s[i+2]) && isOctal(s[i+3]) {
|
||||
b.WriteByte((s[i+1]-'0')*64 + (s[i+2]-'0')*8 + (s[i+3] - '0'))
|
||||
i += 3
|
||||
continue
|
||||
}
|
||||
b.WriteByte(s[i])
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isOctal(c byte) bool { return c >= '0' && c <= '7' }
|
||||
@@ -0,0 +1,58 @@
|
||||
package mounts
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseProc(t *testing.T) {
|
||||
// Real /proc/mounts shape, including a pseudo fs and an octal-escaped space.
|
||||
data := `proc /proc proc rw,nosuid,nodev,noexec 0 0
|
||||
/dev/sda1 / ext4 rw,relatime 0 0
|
||||
/dev/sdb1 /mnt/my\040disk ext4 rw 0 0
|
||||
`
|
||||
got := parseProc(data)
|
||||
want := []Mount{
|
||||
{Device: "proc", Mountpoint: "/proc", FSType: "proc", Options: "rw,nosuid,nodev,noexec"},
|
||||
{Device: "/dev/sda1", Mountpoint: "/", FSType: "ext4", Options: "rw,relatime"},
|
||||
{Device: "/dev/sdb1", Mountpoint: "/mnt/my disk", FSType: "ext4", Options: "rw"},
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("parseProc:\n got %+v\nwant %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnescape(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
`/mnt/my\040disk`: "/mnt/my disk",
|
||||
`/no/escapes`: "/no/escapes",
|
||||
`tab\011here`: "tab\there",
|
||||
`back\134slash`: `back\slash`,
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := Unescape(in); got != want {
|
||||
t.Errorf("Unescape(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProc(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
f := filepath.Join(dir, "mounts")
|
||||
if err := os.WriteFile(f, []byte("/dev/sda1 / ext4 rw 0 0\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
old := procMounts
|
||||
procMounts = f
|
||||
defer func() { procMounts = old }()
|
||||
|
||||
got, err := Proc()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 1 || got[0].Mountpoint != "/" {
|
||||
t.Errorf("Proc() = %+v", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
// Package openapitest holds a cross-module guard: it registers every module
|
||||
// into one Huma API and generates the spec. Huma names OpenAPI schemas by Go
|
||||
// type name alone (not package-qualified), so two modules declaring an
|
||||
// identically named response type (e.g. two "ListOutput"s) collide only when
|
||||
// registered together - a failure invisible to per-package `go build`/`go test`.
|
||||
package openapitest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/auditlog"
|
||||
"nadir/internal/auth"
|
||||
"nadir/internal/meta"
|
||||
"nadir/internal/module"
|
||||
"nadir/internal/modules/audit"
|
||||
"nadir/internal/modules/groups"
|
||||
"nadir/internal/modules/networking"
|
||||
"nadir/internal/modules/packages"
|
||||
"nadir/internal/modules/services"
|
||||
"nadir/internal/modules/storage"
|
||||
"nadir/internal/modules/system"
|
||||
"nadir/internal/modules/terminal"
|
||||
"nadir/internal/modules/users"
|
||||
"nadir/internal/rbac"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
)
|
||||
|
||||
func TestOpenAPISchemaNoCollisions(t *testing.T) {
|
||||
auditStore, err := auditlog.New(filepath.Join(t.TempDir(), "audit.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sessions, err := auth.NewSessionStore(filepath.Join(t.TempDir(), "sessions.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
roles := rbac.New()
|
||||
|
||||
// Mirror cmd/server.go's registration so this catches real collisions.
|
||||
mods := []module.Module{
|
||||
system.New(),
|
||||
services.New(nil),
|
||||
users.New(),
|
||||
groups.New(),
|
||||
packages.New(),
|
||||
networking.New(),
|
||||
storage.New(),
|
||||
audit.New(auditStore),
|
||||
terminal.New(sessions),
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humago.New(mux, huma.DefaultConfig("test", "1.0.0"))
|
||||
for _, m := range mods {
|
||||
m.Register(api) // huma panics here on a duplicate schema name
|
||||
}
|
||||
meta.Register(api, mods)
|
||||
meta.RegisterHealth(api, sessions)
|
||||
meta.RegisterWhoami(api, sessions, roles, mods)
|
||||
auth.RegisterLogin(api, sessions, auditStore, true)
|
||||
auth.RegisterLogout(api, sessions, true)
|
||||
|
||||
// Force full schema resolution.
|
||||
if _, err := api.OpenAPI().YAML(); err != nil {
|
||||
t.Fatalf("OpenAPI generation failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
// Package oscmd holds small helpers shared by modules that shell out to system
|
||||
// tools (hostnamectl, timedatectl, systemctl, …): a command runner that
|
||||
// surfaces stderr, and the common "status: ok" HTTP response.
|
||||
package oscmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cmdTimeout caps how long a synchronous system command may run before it is
|
||||
// killed, so a wedged tool can't tie up a request indefinitely.
|
||||
//
|
||||
// It bounds hangs for the plain Run path (which uses context.Background()).
|
||||
// RunContext additionally propagates client cancellation. Long ops (package
|
||||
// install/upgrade) use the streaming runners below, which take a ctx and are
|
||||
// uncapped.
|
||||
const cmdTimeout = 60 * time.Second
|
||||
|
||||
// CommandRunner allows overriding the command execution function for testing.
|
||||
var CommandRunner = func(ctx context.Context, name string, args ...string) *exec.Cmd {
|
||||
mockMu.Lock()
|
||||
handler, ok := mockCmds[name]
|
||||
mockMu.Unlock()
|
||||
|
||||
if ok {
|
||||
mockRes := handler(args)
|
||||
tmpFile, err := os.CreateTemp("", "nadir-mock-*.json")
|
||||
if err != nil {
|
||||
log.Printf("oscmd mock: failed to create temp file: %v", err)
|
||||
return exec.CommandContext(ctx, name, args...)
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(tmpFile)
|
||||
if err := encoder.Encode(mockRes); err != nil {
|
||||
log.Printf("oscmd mock: failed to write json: %v", err)
|
||||
tmpFile.Close()
|
||||
return exec.CommandContext(ctx, name, args...)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
tmpFile.Close()
|
||||
|
||||
cmd := exec.CommandContext(ctx, os.Args[0])
|
||||
cmd.Env = append(os.Environ(),
|
||||
"GO_WANT_HELPER_PROCESS=1",
|
||||
"NADIR_MOCK_FILE="+tmpPath,
|
||||
)
|
||||
return cmd
|
||||
}
|
||||
|
||||
return exec.CommandContext(ctx, name, args...)
|
||||
}
|
||||
|
||||
// Run executes name with args and returns trimmed stdout. On failure it wraps
|
||||
// the command's stderr (falling back to the exec error) so handlers can surface
|
||||
// a meaningful message instead of a bare "exit status 1".
|
||||
//
|
||||
// Run uses context.Background(): it is bounded by cmdTimeout but not tied to any
|
||||
// request. Callers running a slow command on behalf of a request (so a client
|
||||
// disconnect should kill it) should use RunContext instead.
|
||||
func Run(name string, args ...string) (string, error) {
|
||||
return RunContext(context.Background(), name, args...)
|
||||
}
|
||||
|
||||
// RunContext is Run with a caller-supplied context. The command is killed when
|
||||
// ctx is cancelled (e.g. the client disconnected) or after cmdTimeout, whichever
|
||||
// comes first.
|
||||
func RunContext(ctx context.Context, name string, args ...string) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, cmdTimeout)
|
||||
defer cancel()
|
||||
cmd := CommandRunner(ctx, name, args...)
|
||||
var stderr strings.Builder
|
||||
cmd.Stderr = &stderr
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
if msg := strings.TrimSpace(stderr.String()); msg != "" {
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
// RunStdin is Run with data fed to the command's stdin. Use it for secrets
|
||||
// (e.g. piping "user:password" to chpasswd) so they never appear in argv/ps.
|
||||
func RunStdin(stdin, name string, args ...string) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cmdTimeout)
|
||||
defer cancel()
|
||||
cmd := CommandRunner(ctx, name, args...)
|
||||
cmd.Stdin = strings.NewReader(stdin)
|
||||
var stderr strings.Builder
|
||||
cmd.Stderr = &stderr
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
if msg := strings.TrimSpace(stderr.String()); msg != "" {
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
// RunLines runs the command and splits stdout into non-empty lines.
|
||||
func RunLines(name string, args ...string) ([]string, error) {
|
||||
out, err := Run(name, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if out == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
return strings.Split(out, "\n"), nil
|
||||
}
|
||||
|
||||
// RunStatus runs name and returns trimmed stdout plus the process exit code. It
|
||||
// does NOT treat a non-zero exit as an error, because some tools signal state
|
||||
// that way (dnf `check-update` exits 100 when updates exist; pacman `-Qu` exits
|
||||
// 1 when there are none). err is set only when the command could not be run.
|
||||
func RunStatus(name string, args ...string) (string, int, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cmdTimeout)
|
||||
defer cancel()
|
||||
cmd := CommandRunner(ctx, name, args...)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
var ee *exec.ExitError
|
||||
if errors.As(err, &ee) {
|
||||
return strings.TrimSpace(string(out)), ee.ExitCode(), nil
|
||||
}
|
||||
return "", -1, err
|
||||
}
|
||||
return strings.TrimSpace(string(out)), 0, nil
|
||||
}
|
||||
|
||||
// RunStream runs a long-lived command (e.g. `journalctl -f`) and pushes its
|
||||
// stdout lines onto the returned channel until the process exits or ctx is
|
||||
// cancelled. Cancelling ctx kills the process (exec.CommandContext) and closes
|
||||
// the channel, so an SSE handler can stop simply by cancelling its request
|
||||
// context (client disconnect).
|
||||
func RunStream(ctx context.Context, name string, args ...string) (<-chan string, error) {
|
||||
cmd := CommandRunner(ctx, name, args...)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch := make(chan string)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer cmd.Wait() // reap; process is already killed on ctx cancel
|
||||
sc := bufio.NewScanner(stdout)
|
||||
sc.Buffer(make([]byte, 64*1024), 1024*1024) // tolerate long log lines
|
||||
for sc.Scan() {
|
||||
select {
|
||||
case ch <- sc.Text():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
// A scan error (e.g. a line over the buffer cap) ends the loop silently
|
||||
// otherwise; surface it so a truncated stream is diagnosable. Context
|
||||
// cancellation returns above, so this only fires on a real read error.
|
||||
if err := sc.Err(); err != nil && ctx.Err() == nil {
|
||||
log.Printf("oscmd: stream %s: %v", name, err)
|
||||
}
|
||||
}()
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// RunStreamCombined runs a command and streams its merged stdout+stderr line by
|
||||
// line, the way the command's own terminal output looks. Lines are split on \r
|
||||
// as well as \n so progress redraws (apt/dnf) stream instead of buffering until
|
||||
// a newline. extraEnv is appended to the process environment.
|
||||
//
|
||||
// It returns a lines channel and a one-shot error channel that delivers the
|
||||
// process's exit status after the lines channel closes (nil = success).
|
||||
// Cancelling ctx kills the process.
|
||||
func RunStreamCombined(ctx context.Context, extraEnv []string, name string, args ...string) (<-chan string, <-chan error, error) {
|
||||
cmd := CommandRunner(ctx, name, args...)
|
||||
if len(extraEnv) > 0 {
|
||||
if cmd.Env == nil {
|
||||
cmd.Env = append(os.Environ(), extraEnv...)
|
||||
} else {
|
||||
cmd.Env = append(cmd.Env, extraEnv...)
|
||||
}
|
||||
}
|
||||
// One OS pipe for both streams gives the natural interleaving; passing an
|
||||
// *os.File hands the fd straight to the child (no copy goroutine).
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cmd.Stdout, cmd.Stderr = pw, pw
|
||||
if err := cmd.Start(); err != nil {
|
||||
pr.Close()
|
||||
pw.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
pw.Close() // the child holds its own dup; the reader sees EOF when it exits
|
||||
|
||||
lines := make(chan string)
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(lines)
|
||||
defer func() { errc <- cmd.Wait() }()
|
||||
defer pr.Close()
|
||||
sc := bufio.NewScanner(pr)
|
||||
sc.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
sc.Split(scanLinesCR)
|
||||
for sc.Scan() {
|
||||
if line := sc.Text(); line != "" {
|
||||
select {
|
||||
case lines <- line:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Surface a read error (e.g. a line over the buffer cap) so a truncated
|
||||
// stream is diagnosable. ctx cancellation returns above, so this only
|
||||
// fires on a real error; the exit status still travels via errc.
|
||||
if err := sc.Err(); err != nil && ctx.Err() == nil {
|
||||
log.Printf("oscmd: stream %s: %v", name, err)
|
||||
}
|
||||
}()
|
||||
return lines, errc, nil
|
||||
}
|
||||
|
||||
// scanLinesCR is a bufio.SplitFunc that breaks on either \r or \n, so terminal
|
||||
// progress redraws (carriage returns) surface as they happen.
|
||||
func scanLinesCR(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := bytes.IndexAny(data, "\r\n"); i >= 0 {
|
||||
return i + 1, data[:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil // request more data
|
||||
}
|
||||
|
||||
// ParseKV parses key=value lines (e.g. from timedatectl/systemctl show) into a map.
|
||||
func ParseKV(lines []string) map[string]string {
|
||||
m := make(map[string]string, len(lines))
|
||||
for _, line := range lines {
|
||||
if k, v, ok := strings.Cut(line, "="); ok {
|
||||
m[k] = v
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// StatusOutput is the shared response for write operations that just report
|
||||
// success. Reusing one type means all such endpoints share a single OpenAPI
|
||||
// schema.
|
||||
type StatusOutput struct {
|
||||
Body struct {
|
||||
Status string `json:"status" example:"ok" doc:"Always \"ok\" on success"`
|
||||
}
|
||||
}
|
||||
|
||||
// OK returns a populated StatusOutput.
|
||||
func OK() *StatusOutput {
|
||||
out := &StatusOutput{}
|
||||
out.Body.Status = "ok"
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Mocking helpers for testing ---------------------------------------------
|
||||
|
||||
// MockCommand holds the behavior for a mocked command.
|
||||
type MockCommand struct {
|
||||
Stdout string `json:"stdout"`
|
||||
Stderr string `json:"stderr"`
|
||||
ExitCode int `json:"exit_code"`
|
||||
Lines []string `json:"lines,omitempty"`
|
||||
DelayMs int `json:"delay_ms,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
mockMu sync.Mutex
|
||||
mockCmds = make(map[string]func(args []string) MockCommand)
|
||||
)
|
||||
|
||||
// SetMock registers a mock handler function for the given command name.
|
||||
func SetMock(name string, handler func(args []string) MockCommand) {
|
||||
mockMu.Lock()
|
||||
defer mockMu.Unlock()
|
||||
mockCmds[name] = handler
|
||||
}
|
||||
|
||||
// ClearMocks removes all registered mock command handlers.
|
||||
func ClearMocks() {
|
||||
mockMu.Lock()
|
||||
defer mockMu.Unlock()
|
||||
clear(mockCmds)
|
||||
}
|
||||
|
||||
// RunHelperProcess executes the mock helper process logic if GO_WANT_HELPER_PROCESS is set.
|
||||
// It returns true if it ran (and exits the process), false otherwise.
|
||||
func RunHelperProcess() bool {
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||
return false
|
||||
}
|
||||
mockFile := os.Getenv("NADIR_MOCK_FILE")
|
||||
if mockFile == "" {
|
||||
os.Exit(1)
|
||||
}
|
||||
defer os.Remove(mockFile)
|
||||
|
||||
data, err := os.ReadFile(mockFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "mock helper: read failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var mock MockCommand
|
||||
if err := json.Unmarshal(data, &mock); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "mock helper: unmarshal failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(mock.Lines) > 0 {
|
||||
for _, line := range mock.Lines {
|
||||
fmt.Fprintln(os.Stdout, line)
|
||||
if mock.DelayMs > 0 {
|
||||
time.Sleep(time.Duration(mock.DelayMs) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if mock.Stdout != "" {
|
||||
fmt.Fprint(os.Stdout, mock.Stdout)
|
||||
}
|
||||
if mock.Stderr != "" {
|
||||
fmt.Fprint(os.Stderr, mock.Stderr)
|
||||
}
|
||||
}
|
||||
|
||||
os.Exit(mock.ExitCode)
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
package oscmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestRunTrimsStdout(t *testing.T) {
|
||||
out, err := Run("echo", "hello")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != "hello" {
|
||||
t.Errorf("got %q, want %q", out, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSurfacesStderr(t *testing.T) {
|
||||
_, err := Run("sh", "-c", "echo boom >&2; exit 1")
|
||||
if err == nil || !strings.Contains(err.Error(), "boom") {
|
||||
t.Fatalf("expected stderr 'boom' in error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunMissingBinary(t *testing.T) {
|
||||
if _, err := Run("definitely-not-a-real-binary-xyz"); err == nil {
|
||||
t.Fatal("expected error for missing binary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLines(t *testing.T) {
|
||||
got, err := RunLines("printf", `a\nb\nc\n`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := []string{"a", "b", "c"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunLinesEmpty(t *testing.T) {
|
||||
got, err := RunLines("true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Errorf("empty output should yield no lines, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOK(t *testing.T) {
|
||||
if OK().Body.Status != "ok" {
|
||||
t.Error("OK() should report status ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStdin(t *testing.T) {
|
||||
SetMock("cat", func(args []string) MockCommand {
|
||||
return MockCommand{Stdout: "piped input data", ExitCode: 0}
|
||||
})
|
||||
defer ClearMocks()
|
||||
|
||||
got, err := RunStdin("some input", "cat")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "piped input data" {
|
||||
t.Errorf("got %q, want %q", got, "piped input data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStatus(t *testing.T) {
|
||||
SetMock("check-update", func(args []string) MockCommand {
|
||||
return MockCommand{Stdout: "updates available", ExitCode: 100}
|
||||
})
|
||||
defer ClearMocks()
|
||||
|
||||
out, code, err := RunStatus("check-update")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != "updates available" || code != 100 {
|
||||
t.Errorf("got out=%q code=%d, want out=%q code=100", out, code, "updates available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStream(t *testing.T) {
|
||||
SetMock("stream-cmd", func(args []string) MockCommand {
|
||||
return MockCommand{Lines: []string{"line1", "line2"}, DelayMs: 1}
|
||||
})
|
||||
defer ClearMocks()
|
||||
|
||||
ch, err := RunStream(t.Context(), "stream-cmd")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var got []string
|
||||
for line := range ch {
|
||||
got = append(got, line)
|
||||
}
|
||||
|
||||
want := []string{"line1", "line2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunStreamCombined(t *testing.T) {
|
||||
SetMock("combined-cmd", func(args []string) MockCommand {
|
||||
return MockCommand{Lines: []string{"output1", "output2"}, ExitCode: 0}
|
||||
})
|
||||
defer ClearMocks()
|
||||
|
||||
lines, errc, err := RunStreamCombined(t.Context(), nil, "combined-cmd")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var got []string
|
||||
for line := range lines {
|
||||
got = append(got, line)
|
||||
}
|
||||
|
||||
if err := <-errc; err != nil {
|
||||
t.Errorf("stream failed with error: %v", err)
|
||||
}
|
||||
|
||||
want := []string{"output1", "output2"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package rbac
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkCan(b *testing.B) {
|
||||
r := New()
|
||||
r.DefineRole(Role{Name: "admin", ModuleGrants: map[string][]Permission{"*": {"*"}}})
|
||||
r.DefineRole(Role{Name: "viewer", ModuleGrants: map[string][]Permission{
|
||||
"system": {"read"},
|
||||
"services": {"read"},
|
||||
"audit": {"read"},
|
||||
}})
|
||||
r.DefineRole(Role{Name: "operator", ModuleGrants: map[string][]Permission{
|
||||
"system": {"read", "root"},
|
||||
"services": {"read", "write"},
|
||||
}})
|
||||
|
||||
r.AssignRole("alice", "admin")
|
||||
r.AssignRole("bob", "viewer")
|
||||
r.AssignRole("charlie", "viewer")
|
||||
r.AssignRole("charlie", "operator")
|
||||
|
||||
// b.Loop automatically excludes the setup above from timing (no ResetTimer
|
||||
// needed) and keeps the calls from being optimised away.
|
||||
for b.Loop() {
|
||||
_ = r.Can("charlie", "services", "write")
|
||||
_ = r.Can("bob", "system", "root")
|
||||
_ = r.Can("alice", "audit", "write")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package rbac
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"nadir/internal/auditlog"
|
||||
"nadir/internal/auth"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
)
|
||||
|
||||
func RbacMiddleware(api huma.API, sessions *auth.SessionStore, tokens *auth.TokenAuth, roles *RBAC, auditor *auditlog.Store) func(huma.Context, func(huma.Context)) {
|
||||
return func(ctx huma.Context, next func(huma.Context)) {
|
||||
// CSRF defense-in-depth (beyond the SameSite=Strict cookie): reject any
|
||||
// state-changing request whose Origin doesn't match our Host. Runs before
|
||||
// the permission check so it also covers the public login/logout writes.
|
||||
//
|
||||
// Bearer tokens are exempt by construction: browsers don't auto-attach an
|
||||
// Authorization header, so a token request carries no ambient authority for
|
||||
// CSRF to exploit. Such requests also send no Origin, so they pass here too.
|
||||
if !sameOriginOK(ctx) {
|
||||
huma.WriteErr(api, ctx, http.StatusForbidden, "cross-origin request blocked")
|
||||
return
|
||||
}
|
||||
|
||||
op := ctx.Operation()
|
||||
perm, _ := op.Metadata["permission"].(string)
|
||||
if perm == "" {
|
||||
next(ctx)
|
||||
return
|
||||
}
|
||||
moduleID, _ := op.Metadata["module"].(string)
|
||||
|
||||
// Machine-to-machine: a Bearer token authenticates as its token name (a
|
||||
// dashboard managing N nodes needs no PAM session). Checked before the
|
||||
// cookie. The token name is the RBAC subject - assign it a role in
|
||||
// config.yaml's `assignments`, same as a username.
|
||||
if raw, isBearer := auth.BearerToken(ctx.Header("Authorization")); isBearer {
|
||||
if tokens == nil {
|
||||
huma.WriteErr(api, ctx, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
name, ok, throttled := tokens.Verify(auth.ClientIP(ctx.Context()), raw)
|
||||
if throttled {
|
||||
huma.WriteErr(api, ctx, http.StatusTooManyRequests, "too many failed token attempts; wait a minute")
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
huma.WriteErr(api, ctx, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
// Audit actor is prefixed "token:" so the trail distinguishes a token
|
||||
// from a human with the same name; RBAC still checks the bare name.
|
||||
if !roles.Can(name, moduleID, Permission(perm)) {
|
||||
huma.WriteErr(api, ctx, http.StatusForbidden, "forbidden")
|
||||
record(auditor, "token:"+name, op, ctx.URL().Path, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next(ctx)
|
||||
record(auditor, "token:"+name, op, ctx.URL().Path, ctx.Status())
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := huma.ReadCookie(ctx, "nadir_session_id")
|
||||
if err != nil || cookie == nil {
|
||||
huma.WriteErr(api, ctx, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
sess, ok := sessions.GetByToken(cookie.Value)
|
||||
if !ok {
|
||||
huma.WriteErr(api, ctx, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
if !roles.Can(sess.Username, moduleID, Permission(perm)) {
|
||||
huma.WriteErr(api, ctx, http.StatusForbidden, "forbidden")
|
||||
record(auditor, sess.Username, op, ctx.URL().Path, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next(ctx)
|
||||
record(auditor, sess.Username, op, ctx.URL().Path, ctx.Status())
|
||||
}
|
||||
}
|
||||
|
||||
// sameOriginOK allows safe methods and any request whose Origin header matches
|
||||
// the request Host. A missing Origin (non-browser client, or a same-origin
|
||||
// navigation) is allowed - CSRF is a browser-only concern. Only browser-issued
|
||||
// cross-origin writes, which always carry an Origin, are rejected.
|
||||
func sameOriginOK(ctx huma.Context) bool {
|
||||
switch ctx.Method() {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||
return true
|
||||
}
|
||||
origin := ctx.Header("Origin")
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Host == ctx.Host()
|
||||
}
|
||||
|
||||
// record logs a mutation (anything but a read) to the audit trail. Reads are
|
||||
// skipped to keep the trail to "who changed what". Best-effort: a logging
|
||||
// failure is reported to the server log, never to the caller.
|
||||
func record(auditor *auditlog.Store, username string, op *huma.Operation, path string, status int) {
|
||||
if op.Method == http.MethodGet || op.Method == http.MethodHead || op.Method == http.MethodOptions {
|
||||
return
|
||||
}
|
||||
// SSE handlers stream via BodyWriter and never call SetStatus, so huma's
|
||||
// context reports status 0 even though net/http has implicitly sent a 200.
|
||||
// A streamed response that reached here passed RBAC and started, so record
|
||||
// it as 200 (stream-level failures surface as `error` events, not codes).
|
||||
if status == 0 {
|
||||
status = http.StatusOK
|
||||
}
|
||||
moduleID, _ := op.Metadata["module"].(string)
|
||||
auditor.Record(username, op.Method, path, moduleID, status)
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package rbac
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"nadir/internal/auditlog"
|
||||
"nadir/internal/auth"
|
||||
"nadir/internal/oscmd"
|
||||
|
||||
"github.com/danielgtaylor/huma/v2"
|
||||
"github.com/danielgtaylor/huma/v2/adapters/humago"
|
||||
"github.com/danielgtaylor/huma/v2/humatest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if oscmd.RunHelperProcess() {
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestRbacMiddleware(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
auditStore, err := auditlog.New(filepath.Join(tempDir, "audit.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer auditStore.Close()
|
||||
|
||||
sessions, err := auth.NewSessionStore(filepath.Join(tempDir, "sessions.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tokenStore, err := auth.NewTokenStore(filepath.Join(tempDir, "tokens.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tokenAuth := auth.NewTokenAuth(tokenStore)
|
||||
|
||||
r := New()
|
||||
r.DefineRole(Role{
|
||||
Name: "test-role",
|
||||
ModuleGrants: map[string][]Permission{
|
||||
"test-mod": {Read, Write},
|
||||
},
|
||||
})
|
||||
r.AssignRole("alice", "test-role")
|
||||
// A machine credential is just another RBAC subject: the token name is
|
||||
// assigned a role exactly like a username.
|
||||
r.AssignRole("dash", "test-role")
|
||||
|
||||
mux := http.NewServeMux()
|
||||
api := humatest.Wrap(t, humago.New(mux, huma.DefaultConfig("Test", "1.0.0")))
|
||||
|
||||
api.UseMiddleware(RbacMiddleware(api, sessions, tokenAuth, r, auditStore))
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "public-get",
|
||||
Method: "GET",
|
||||
Path: "/public",
|
||||
}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
|
||||
return &struct{ Body string }{Body: "public"}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "gated-get",
|
||||
Method: "GET",
|
||||
Path: "/gated-read",
|
||||
Metadata: map[string]any{"module": "test-mod", "permission": "read"},
|
||||
}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
|
||||
return &struct{ Body string }{Body: "gated-read"}, nil
|
||||
})
|
||||
|
||||
huma.Register(api, huma.Operation{
|
||||
OperationID: "gated-post",
|
||||
Method: "POST",
|
||||
Path: "/gated-write",
|
||||
Metadata: map[string]any{"module": "test-mod", "permission": "write"},
|
||||
}, func(ctx context.Context, _ *struct{}) (*struct{ Body string }, error) {
|
||||
return &struct{ Body string }{Body: "gated-write"}, nil
|
||||
})
|
||||
|
||||
// 1. Test public route
|
||||
resp := api.Get("/public")
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("public GET: got status %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 2. Test gated route without cookie -> 401 Unauthorized
|
||||
resp = api.Get("/gated-read")
|
||||
if resp.Code != http.StatusUnauthorized {
|
||||
t.Errorf("gated GET no cookie: got status %d, want %d", resp.Code, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// 3. Test gated route with invalid cookie -> 401 Unauthorized
|
||||
resp = api.Get("/gated-read", "Cookie: nadir_session_id=invalid")
|
||||
if resp.Code != http.StatusUnauthorized {
|
||||
t.Errorf("gated GET invalid cookie: got status %d, want %d", resp.Code, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Create valid session
|
||||
token, err := sessions.Create("alice")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// 4. Test gated route with valid cookie -> 200 OK
|
||||
resp = api.Get("/gated-read", "Cookie: nadir_session_id="+token)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("gated GET valid cookie: got status %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 5. Test CSRF violation: POST with mismatched Origin header -> 403 Forbidden
|
||||
resp = api.Post("/gated-write", "Cookie: nadir_session_id="+token, "Origin: http://evil.com", "Host: example.com", struct{}{})
|
||||
if resp.Code != http.StatusForbidden {
|
||||
t.Errorf("CSRF mismatched Origin: got status %d, want %d", resp.Code, http.StatusForbidden)
|
||||
}
|
||||
|
||||
// 6. Test CSRF success: POST with matching Origin header -> 200 OK
|
||||
resp = api.Post("/gated-write", "Cookie: nadir_session_id="+token, "Origin: http://example.com", "Host: example.com", struct{}{})
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("CSRF matching Origin: got status %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 7. Test gated route with unauthorized user -> 403 Forbidden
|
||||
tokenBob, err := sessions.Create("bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp = api.Get("/gated-read", "Cookie: nadir_session_id="+tokenBob)
|
||||
if resp.Code != http.StatusForbidden {
|
||||
t.Errorf("bob unauthorized GET: got status %d, want %d", resp.Code, http.StatusForbidden)
|
||||
}
|
||||
|
||||
// 8. Bearer token for an assigned name -> 200 OK
|
||||
rawToken, err := tokenStore.Create("dash")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp = api.Get("/gated-read", "Authorization: Bearer "+rawToken)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("valid bearer GET: got status %d, want %d", resp.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 9. Bogus bearer token -> 401 Unauthorized
|
||||
resp = api.Get("/gated-read", "Authorization: Bearer nad_deadbeef")
|
||||
if resp.Code != http.StatusUnauthorized {
|
||||
t.Errorf("bogus bearer GET: got status %d, want %d", resp.Code, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// 10. Bearer token with no role assignment -> 403 Forbidden
|
||||
rawUnassigned, err := tokenStore.Create("orphan")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp = api.Get("/gated-read", "Authorization: Bearer "+rawUnassigned)
|
||||
if resp.Code != http.StatusForbidden {
|
||||
t.Errorf("unassigned bearer GET: got status %d, want %d", resp.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package rbac
|
||||
|
||||
type Permission string
|
||||
|
||||
const (
|
||||
Read Permission = "read"
|
||||
Write Permission = "write"
|
||||
// Root is the high-impact tier: destructive or irreversible operations
|
||||
// (reboot, shutdown, account deletion, firewall flush, …) that callers
|
||||
// should be able to grant separately from routine writes.
|
||||
Root Permission = "root"
|
||||
All Permission = "*" // wildcard: matches any permission
|
||||
)
|
||||
|
||||
// Wildcard is the module-key value that matches all modules in a Role's grants.
|
||||
const Wildcard = "*"
|
||||
|
||||
type Role struct {
|
||||
Name string
|
||||
ModuleGrants map[string][]Permission // module ID (or "*") -> permissions (each may be "*")
|
||||
}
|
||||
|
||||
type RBAC struct {
|
||||
roles map[string]Role
|
||||
userRoles map[string][]string
|
||||
}
|
||||
|
||||
func New() *RBAC {
|
||||
return &RBAC{
|
||||
roles: make(map[string]Role),
|
||||
userRoles: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RBAC) DefineRole(role Role) {
|
||||
r.roles[role.Name] = role
|
||||
}
|
||||
|
||||
// RoleExists reports whether a role with the given name has been defined.
|
||||
func (r *RBAC) RoleExists(name string) bool {
|
||||
_, ok := r.roles[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *RBAC) AssignRole(username, roleName string) {
|
||||
r.userRoles[username] = append(r.userRoles[username], roleName)
|
||||
}
|
||||
|
||||
// Can checks whether the user holds any role granting (module, perm),
|
||||
// honoring "*" wildcards on both the module key and inside the permission list.
|
||||
func (r *RBAC) Can(username, module string, perm Permission) bool {
|
||||
for _, roleName := range r.userRoles[username] {
|
||||
role, ok := r.roles[roleName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// Check the exact module key AND the wildcard module key.
|
||||
for _, key := range []string{module, Wildcard} {
|
||||
grants, ok := role.ModuleGrants[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, p := range grants {
|
||||
if p == perm || p == All {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package rbac
|
||||
|
||||
import "testing"
|
||||
|
||||
// build sets up an RBAC store with the given roles, then assigns them all to
|
||||
// user "u".
|
||||
func build(t *testing.T, roles ...Role) *RBAC {
|
||||
t.Helper()
|
||||
r := New()
|
||||
for _, role := range roles {
|
||||
r.DefineRole(role)
|
||||
r.AssignRole("u", role.Name)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestCan(t *testing.T) {
|
||||
admin := Role{Name: "admin", ModuleGrants: map[string][]Permission{Wildcard: {All}}}
|
||||
reader := Role{Name: "reader", ModuleGrants: map[string][]Permission{Wildcard: {Read}}}
|
||||
sysWrite := Role{Name: "sysw", ModuleGrants: map[string][]Permission{"system": {Read, Write}}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
role Role
|
||||
module string
|
||||
perm Permission
|
||||
want bool
|
||||
}{
|
||||
{"wildcard module + wildcard perm", admin, "system", Root, true},
|
||||
{"wildcard module + wildcard perm, other module", admin, "services", Write, true},
|
||||
{"wildcard module, read only, read", reader, "anything", Read, true},
|
||||
{"wildcard module, read only, write denied", reader, "anything", Write, false},
|
||||
{"exact module exact perm", sysWrite, "system", Write, true},
|
||||
{"exact module missing perm", sysWrite, "system", Root, false},
|
||||
{"exact module, different module denied", sysWrite, "services", Read, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := build(t, tt.role)
|
||||
if got := r.Can("u", tt.module, tt.perm); got != tt.want {
|
||||
t.Errorf("Can(u, %q, %q) = %v, want %v", tt.module, tt.perm, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanNoRolesDenied(t *testing.T) {
|
||||
r := New()
|
||||
if r.Can("nobody", "system", Read) {
|
||||
t.Fatal("user with no roles was granted access")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanAssignedUndefinedRoleSkipped(t *testing.T) {
|
||||
r := New()
|
||||
r.AssignRole("u", "ghost") // never defined
|
||||
if r.Can("u", "system", Read) {
|
||||
t.Fatal("undefined role granted access")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanUnionOfRoles(t *testing.T) {
|
||||
r := build(t,
|
||||
Role{Name: "a", ModuleGrants: map[string][]Permission{"system": {Read}}},
|
||||
Role{Name: "b", ModuleGrants: map[string][]Permission{"services": {Write}}},
|
||||
)
|
||||
if !r.Can("u", "system", Read) || !r.Can("u", "services", Write) {
|
||||
t.Fatal("union of roles not honored")
|
||||
}
|
||||
if r.Can("u", "system", Write) {
|
||||
t.Fatal("permission leaked across modules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleExists(t *testing.T) {
|
||||
r := New()
|
||||
r.DefineRole(Role{Name: "x"})
|
||||
if !r.RoleExists("x") {
|
||||
t.Error("defined role reported missing")
|
||||
}
|
||||
if r.RoleExists("y") {
|
||||
t.Error("undefined role reported present")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDbDCCAlSgAwIBAgIUfPrui00oPGw5okXJNXo0RP4dEMEwDQYJKoZIhvcNAQEL
|
||||
BQAwLjEYMBYGA1UECgwPbmFkaXItZGV2LWxvY2FsMRIwEAYDVQQDDAlsb2NhbGhv
|
||||
c3QwHhcNMjYwNjIwMDgwNDQ1WhcNMjcwNjIwMDgwNDQ1WjAuMRgwFgYDVQQKDA9u
|
||||
YWRpci1kZXYtbG9jYWwxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcN
|
||||
AQEBBQADggEPADCCAQoCggEBALifDPyGxdt32aP/QKbq+DISfwNzurcNu9hCqQT9
|
||||
keMgarPGnLd7J1aI6c12ypVz0DtFwfj9ALBTS7ga1x2FTo6cZtCF9tA0v6cpCEQs
|
||||
mC0ecv8VOnWD9hy9teh6a/Elc48FMg1aug2/bPdHiKCfUT7yd7A7YEOl/o1RZj3A
|
||||
1MkJllJ9CrHpgjmZBuypI21V0pYl1Wxwh3NvtMOAEkOp2OWZFsltzHoAtVbzPTR+
|
||||
5iDmrPUrmzToth0cakqBZlqlvU1e9bmSdiLdUj3Z0poZouGyjuuEixZzX3UJJwsM
|
||||
DI8juRbWJioKjl7P+M5D6rZyB0RoO1n4N33lTluMqBeH8M8CAwEAAaOBgTB/MB0G
|
||||
A1UdDgQWBBT9ugmFk9gRxY8Gkys7U4sYiiHiBzAfBgNVHSMEGDAWgBT9ugmFk9gR
|
||||
xY8Gkys7U4sYiiHiBzAPBgNVHRMBAf8EBTADAQH/MCwGA1UdEQQlMCOCCWxvY2Fs
|
||||
aG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEA
|
||||
nT3ca+8u8uhPlqj4M/uLwdxYsvZbYDRkAs058aXSjFl2bvF+jYqJhbO9YR5SkqPk
|
||||
fFXG2/d7gtQZAMpgfKI3fIaEkWGUvj98MDqTbSfksspbTpIQUBssV8eBRUi+L74f
|
||||
Midm6Ua6+M0kqCCwyCVs6befTBSnmxUa/xsMfHwaDH2YoAVxySnuT5/lU6y+X5fP
|
||||
eAsGokINX2rZNBZbQofQwJ5y1rfPye9u6giPMcyEjnyHUvHuNUvBcxCfNy8gvs7T
|
||||
+/kUbO1uQFRm1HcXh5dfdkb1QjottH1JGoF0AA3l1XqP6m9EdtFldt7WyIHqr7NJ
|
||||
D+qOxooyToC2JLCaxDAoZw==
|
||||
-----END CERTIFICATE-----
|
||||
+28
@@ -0,0 +1,28 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC4nwz8hsXbd9mj
|
||||
/0Cm6vgyEn8Dc7q3DbvYQqkE/ZHjIGqzxpy3eydWiOnNdsqVc9A7RcH4/QCwU0u4
|
||||
GtcdhU6OnGbQhfbQNL+nKQhELJgtHnL/FTp1g/YcvbXoemvxJXOPBTINWroNv2z3
|
||||
R4ign1E+8newO2BDpf6NUWY9wNTJCZZSfQqx6YI5mQbsqSNtVdKWJdVscIdzb7TD
|
||||
gBJDqdjlmRbJbcx6ALVW8z00fuYg5qz1K5s06LYdHGpKgWZapb1NXvW5knYi3VI9
|
||||
2dKaGaLhso7rhIsWc191CScLDAyPI7kW1iYqCo5ez/jOQ+q2cgdEaDtZ+Dd95U5b
|
||||
jKgXh/DPAgMBAAECggEAAWckAB8+Dabhfn+IDDyo2iiN0obkmlN+Y+xNwH30x9cN
|
||||
OIR/2F0VNXEg5bDLZUtV/71N9ghmIvDfGG0LyWuj5y2FEnySHY7pDeof5/S2y1D5
|
||||
6rpMkWwJSLqgUT3s6A4yzJlrgfJ4i3Yy68YdYasUQPgytKIe3yS5xHUj48A9XbGz
|
||||
ppbkHg76WuU9AUhCIgfSq0Xzz6cYv09eFx8ueiBlB2fGbjTs0j5CMlxLFkciNrYZ
|
||||
2VNdLt5UuNvQFjj5gH0nWe76WAb2PjeQBzgiFdRKGh19AuClWG4URXLvnScGvSti
|
||||
DuDY0QVusTCLeUJXx8bjxlK7wb7SKDGFGxjp7kr9gQKBgQDfWQG4Mts0fegxbsPy
|
||||
WM42WQvSfxn6SSVDnoxAG4TZ0+FpaUKQKXgQ0z3sC8P78obE2aC+cPKWK8ZWHNt7
|
||||
6yaflG8yn4hodfXqb2JVp4WltXwPvG9s2+GnwFqs6yooXiRtnMp99ZUpqvnCg1i7
|
||||
e1mGe3nDBVMIquhP0VoVbFBgoQKBgQDTnKxO7IA2JtLQFWgQCDcO0uL2yvGXPLze
|
||||
5pyjOCVuNUnA60ubVBdS+bB14KYZjdPLnX08QqtIdKjessHRAi36NStAA8cWUoeq
|
||||
O10dWjmM5oMGtOnSn564+fj7nwoLKLWq9PQq18zuJll4MU9ntycgjAL9vTc5nKoT
|
||||
0BlepJMrbwKBgQDbx5BDm/fM3aDhE+hJ0E2LeXCCwIPloJjEw32rj+jZGQCVY/kW
|
||||
N1ho5hXm82T1xiAMEUN2Y1qzn3vaPSdV933YRo5tuELY2EsXWGfhdamz+LSOH5Vd
|
||||
/7k8A7K2ueqQMqOSIVm5PTJ9ADwpxmpIgwcDqPmWiOS+gL9927rTnfQyQQKBgHUG
|
||||
WNgQvFq2H7GJlRIAqQoen/uhgfeEVGLkn8032KNY/t+cgCR3Xaq6gNa/lLvfDji1
|
||||
cLOpnvWj5lu5+atvjCOp0bBGJox2uaXvzG/WHKuKMv27gO/E7E8ZlpL4geJn8geI
|
||||
DZu/2gn91U693k7aH95E78aJJIhM1lW8qLsJQoYrAoGACZ5kJQcyasDCTh+cw/dB
|
||||
B3qHhu9R416hr3OZrqDNHfqpI8D26ON87LECPg+veOCuZ3cmj/j6cCuQl1RXabxA
|
||||
B0/kWfPc5M+WU9wgvRuYfBRfsNBVUiz3anEvo74rzFoUdl6rTg2OE2TtBz4ElITr
|
||||
Ju5DB/Mki7nr8L0FHsJzMmY=
|
||||
-----END PRIVATE KEY-----
|
||||
Reference in New Issue
Block a user