package api_test import ( "bytes" "context" "encoding/json" "net/http" "net/http/httptest" "path/filepath" "strconv" "testing" "time" "github.com/go-chi/chi/v5" "vetting/internal/api" "vetting/internal/db" "vetting/internal/events" "vetting/internal/model" "vetting/internal/orchestrator" "vetting/internal/store" ) func setupAgent(t *testing.T) (*api.Agent, int64, string) { t.Helper() path := filepath.Join(t.TempDir(), "vetting.db") conn, err := db.Open(path) if err != nil { t.Fatalf("open db: %v", err) } t.Cleanup(func() { _ = conn.Close() }) hosts := &store.Hosts{DB: conn} runs := &store.Runs{DB: conn} meas := &store.Measurements{DB: conn} subSteps := &store.SubSteps{DB: conn} hostID, err := hosts.Create(context.Background(), model.Host{ Name: "t-host", MAC: "aa:bb:cc:dd:ee:01", WoLBroadcastIP: "10.0.0.255", WoLPort: 9, ExpectedSpecYAML: "memory:\n total_gib: 16\n", }) if err != nil { t.Fatalf("create host: %v", err) } plain, hash, err := orchestrator.IssueRunToken() if err != nil { t.Fatalf("issue token: %v", err) } runID, err := runs.Create(context.Background(), hostID, hash, false) if err != nil { t.Fatalf("create run: %v", err) } return &api.Agent{ Hosts: hosts, Runs: runs, Measurements: meas, SubSteps: subSteps, }, runID, plain } func routedRequest(runID int64, method, path string, body []byte) *http.Request { req := httptest.NewRequest(method, path, bytes.NewReader(body)) // chi.URLParam is read from chi's context routing; fake that here. rctx := chi.NewRouteContext() rctx.URLParams.Add("id", strconv.FormatInt(runID, 10)) return req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) } func TestSensorPersistsBatch(t *testing.T) { a, runID, token := setupAgent(t) batch := api.SensorBatch{Samples: []api.SensorSample{ {Kind: "thermal", Key: "cpu", Value: 47.5, Unit: "C"}, {Kind: "iperf", Key: "throughput_mbps", Value: 938.2, Unit: "Mbps"}, }} buf, _ := json.Marshal(batch) req := routedRequest(runID, http.MethodPost, "/api/v1/runs/"+strconv.FormatInt(runID, 10)+"/sensor", buf) req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() a.Sensor(rr, req) if rr.Code != http.StatusOK { t.Fatalf("status = %d, body = %q", rr.Code, rr.Body.String()) } rows, err := a.Measurements.ListForRun(context.Background(), runID) if err != nil { t.Fatalf("ListForRun: %v", err) } if len(rows) != 2 { t.Fatalf("expected 2 measurements, got %d", len(rows)) } } func TestSensorRejectsBadToken(t *testing.T) { a, runID, _ := setupAgent(t) body, _ := json.Marshal(api.SensorBatch{}) req := routedRequest(runID, http.MethodPost, "/api/v1/runs/"+strconv.FormatInt(runID, 10)+"/sensor", body) req.Header.Set("Authorization", "Bearer wrong-token") rr := httptest.NewRecorder() a.Sensor(rr, req) if rr.Code != http.StatusUnauthorized { t.Fatalf("status = %d, want 401", rr.Code) } } // TestHeartbeatRebootWhenCompleted: once the orchestrator has flipped // the run into Completed, the next heartbeat response must carry // cmd=reboot so the agent reboots the host back to local disk. func TestHeartbeatRebootWhenCompleted(t *testing.T) { a, runID, token := setupAgent(t) // Wire a runner so Heartbeat's TouchHeartbeat call doesn't nil-panic. a.Runner = &orchestrator.Runner{Runs: a.Runs, Hosts: a.Hosts, Stages: &store.Stages{DB: a.Runs.DB}, EventHub: events.NewHub()} if err := a.Runs.SetState(context.Background(), runID, model.StateCompleted); err != nil { t.Fatalf("set state: %v", err) } req := routedRequest(runID, http.MethodPost, "/api/v1/runs/"+strconv.FormatInt(runID, 10)+"/heartbeat", nil) req.Header.Set("Authorization", "Bearer "+token) rr := httptest.NewRecorder() a.Heartbeat(rr, req) if rr.Code != http.StatusOK { t.Fatalf("status = %d, body = %s", rr.Code, rr.Body.String()) } var resp map[string]any if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { t.Fatalf("decode: %v", err) } if resp["cmd"] != "reboot" { t.Fatalf("cmd = %v, want reboot", resp["cmd"]) } } // TestHeartbeatRebootWhenCancelledFromHold: operator hit Cancel on a // FailedHolding run. Because there's no in-flight stage subprocess (the // agent is parked in waitForOverride), the heartbeat must answer with // cmd=reboot — not cmd=cancel_stage which only makes sense mid-stage. // The FailedStage marker is the discriminator: set means we came // through hold; empty means a mid-stage cancel. func TestHeartbeatRebootWhenCancelledFromHold(t *testing.T) { a, runID, token := setupAgent(t) a.Runner = &orchestrator.Runner{Runs: a.Runs, Hosts: a.Hosts, Stages: &store.Stages{DB: a.Runs.DB}, EventHub: events.NewHub()} if err := a.Runs.SetFailedStage(context.Background(), runID, "Storage"); err != nil { t.Fatalf("set failed stage: %v", err) } if err := a.Runs.SetState(context.Background(), runID, model.StateCancelled); err != nil { t.Fatalf("set state: %v", err) } req := routedRequest(runID, http.MethodPost, "/api/v1/runs/"+strconv.FormatInt(runID, 10)+"/heartbeat", nil) req.Header.Set("Authorization", "Bearer "+token) rr := httptest.NewRecorder() a.Heartbeat(rr, req) if rr.Code != http.StatusOK { t.Fatalf("status = %d, body = %s", rr.Code, rr.Body.String()) } var resp map[string]any if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { t.Fatalf("decode: %v", err) } if resp["cmd"] != "reboot" { t.Fatalf("cmd = %v, want reboot", resp["cmd"]) } } // TestHeartbeatCancelStageWhenCancelledMidRun: the mid-stage cancel // path (no FailedStage marker) still answers cmd=cancel_stage so the // agent kills its in-flight subprocess before powering off. This is // the pre-existing behaviour; the hold-cancel branch is additive. func TestHeartbeatCancelStageWhenCancelledMidRun(t *testing.T) { a, runID, token := setupAgent(t) a.Runner = &orchestrator.Runner{Runs: a.Runs, Hosts: a.Hosts, Stages: &store.Stages{DB: a.Runs.DB}, EventHub: events.NewHub()} if err := a.Runs.SetState(context.Background(), runID, model.StateCancelled); err != nil { t.Fatalf("set state: %v", err) } req := routedRequest(runID, http.MethodPost, "/api/v1/runs/"+strconv.FormatInt(runID, 10)+"/heartbeat", nil) req.Header.Set("Authorization", "Bearer "+token) rr := httptest.NewRecorder() a.Heartbeat(rr, req) if rr.Code != http.StatusOK { t.Fatalf("status = %d, body = %s", rr.Code, rr.Body.String()) } var resp map[string]any if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { t.Fatalf("decode: %v", err) } if resp["cmd"] != "cancel_stage" { t.Fatalf("cmd = %v, want cancel_stage", resp["cmd"]) } } // TestResult_RejectsMismatchedStage is the silent-skip guard's unit // test. The Orion failure mode: agent crashes mid-CPUStress, systemd // restarts it, restarted agent replays Inventory and /results it. // Before the guard, the orchestrator advanced StateCPUStress → Storage // on TriggerStageCompleted; CPUStress got marked passed without ever // running. Guard's contract: if body.Stage doesn't match the stage the // run is in, reject with 409 and park the run in FailedHolding with a // failed_stage that names *what* was reported vs. what was expected. func TestResult_RejectsMismatchedStage(t *testing.T) { a, runID, token := setupAgent(t) a.Runner = &orchestrator.Runner{Runs: a.Runs, Hosts: a.Hosts, Stages: &store.Stages{DB: a.Runs.DB}, EventHub: events.NewHub()} // Park the run in CPUStress — the state Orion was in when its // agent crashed. if err := a.Runs.SetState(context.Background(), runID, model.StateCPUStress); err != nil { t.Fatalf("set state: %v", err) } // Restarted agent's hardcoded-Inventory-first behavior: it replays // Inventory and posts a passed result for it. body, _ := json.Marshal(map[string]any{ "stage": "Inventory", "passed": true, }) req := routedRequest(runID, http.MethodPost, "/api/v1/runs/"+strconv.FormatInt(runID, 10)+"/result", body) req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() a.Result(rr, req) if rr.Code != http.StatusConflict { t.Fatalf("status = %d, want 409; body = %s", rr.Code, rr.Body.String()) } after, err := a.Runs.Get(context.Background(), runID) if err != nil { t.Fatalf("get run: %v", err) } if after.State != model.StateFailedHolding { t.Fatalf("run state = %q, want FailedHolding", after.State) } if after.FailedStage == "" { t.Fatalf("failed_stage is empty; expected mismatch label") } // The label must name both sides so the operator can see the // skew without digging through logs. for _, want := range []string{"Inventory", "CPUStress"} { if !bytes.Contains([]byte(after.FailedStage), []byte(want)) { t.Errorf("failed_stage %q missing %q", after.FailedStage, want) } } } // TestResult_AcceptsMatchingStage confirms the guard's complement: when // the agent reports the stage the run is actually in, /result advances // the pipeline normally. Without this, a too-strict guard could reject // every result and freeze all runs. func TestResult_AcceptsMatchingStage(t *testing.T) { a, runID, token := setupAgent(t) a.Runner = &orchestrator.Runner{Runs: a.Runs, Hosts: a.Hosts, Stages: &store.Stages{DB: a.Runs.DB}, EventHub: events.NewHub()} stages := &store.Stages{DB: a.Runs.DB} if err := stages.Seed(context.Background(), runID); err != nil { t.Fatalf("seed stages: %v", err) } if err := a.Runs.SetState(context.Background(), runID, model.StateSMART); err != nil { t.Fatalf("set state: %v", err) } body, _ := json.Marshal(map[string]any{ "stage": "SMART", "passed": true, }) req := routedRequest(runID, http.MethodPost, "/api/v1/runs/"+strconv.FormatInt(runID, 10)+"/result", body) req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() a.Result(rr, req) if rr.Code != http.StatusOK { t.Fatalf("status = %d, want 200; body = %s", rr.Code, rr.Body.String()) } after, err := a.Runs.Get(context.Background(), runID) if err != nil { t.Fatalf("get run: %v", err) } if after.State != model.StateCPUStress { t.Fatalf("run state = %q, want CPUStress after SMART pass", after.State) } } // TestResult_PersistsSubSteps covers the /result handler's contract for // the new sub_steps table: when the agent includes a sub_steps array in // the POST body, each entry lands in the table with an ordinal equal to // its slice index, state derived from passed/skipped, and timestamps // parsed from RFC3339. The guard must let the call through (matching // stage) and sub-steps are written *after* CompleteStage so a persistence // error doesn't wedge the whole run. func TestResult_PersistsSubSteps(t *testing.T) { a, runID, token := setupAgent(t) a.Runner = &orchestrator.Runner{Runs: a.Runs, Hosts: a.Hosts, Stages: &store.Stages{DB: a.Runs.DB}, EventHub: events.NewHub()} stages := &store.Stages{DB: a.Runs.DB} if err := stages.Seed(context.Background(), runID); err != nil { t.Fatalf("seed stages: %v", err) } if err := a.Runs.SetState(context.Background(), runID, model.StateCPUStress); err != nil { t.Fatalf("set state: %v", err) } start := time.Date(2026, 4, 18, 13, 0, 0, 0, time.UTC) end := start.Add(3 * time.Minute) body, _ := json.Marshal(map[string]any{ "stage": "CPUStress", "passed": true, "sub_steps": []map[string]any{ { "name": "CPU pass", "passed": true, "started_at": start.Format(time.RFC3339Nano), "completed_at": end.Format(time.RFC3339Nano), "summary": json.RawMessage(`{"elapsed_secs":180}`), }, { "name": "Memory pass", "passed": false, "started_at": end.Format(time.RFC3339Nano), "completed_at": end.Add(2 * time.Minute).Format(time.RFC3339Nano), }, }, }) req := routedRequest(runID, http.MethodPost, "/api/v1/runs/"+strconv.FormatInt(runID, 10)+"/result", body) req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() a.Result(rr, req) if rr.Code != http.StatusOK { t.Fatalf("status = %d, want 200; body = %s", rr.Code, rr.Body.String()) } rows, err := a.SubSteps.ListForRun(context.Background(), runID) if err != nil { t.Fatalf("ListForRun: %v", err) } if len(rows) != 2 { t.Fatalf("got %d sub-steps, want 2", len(rows)) } if rows[0].Ordinal != 0 || rows[0].Name != "CPU pass" || rows[0].State != model.StagePassed { t.Fatalf("row[0] = %+v", rows[0]) } if rows[1].Ordinal != 1 || rows[1].Name != "Memory pass" || rows[1].State != model.StageFailed { t.Fatalf("row[1] = %+v", rows[1]) } if rows[0].StartedAt == nil || !rows[0].StartedAt.Equal(start) { t.Fatalf("row[0].StartedAt = %v, want %v", rows[0].StartedAt, start) } if rows[0].SummaryJSON != `{"elapsed_secs":180}` { t.Fatalf("row[0].SummaryJSON = %q", rows[0].SummaryJSON) } }