diff --git a/README.md b/README.md index 4a99d6e..5bba2b1 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ Edit `provisioning.yaml`: - `pxe.subnet` — LAN CIDR for proxy-DHCP - `proxmox.existing_node` — IP of any current cluster member - `proxmox.join_fingerprint` — from `pvecm status` on an existing node -- `credentials.ssh_public_key` — public key injected into new hosts - `credentials.root_password_hash` — `mkpasswd -m sha-512` - `infrastructure.base_url` — URL of the Infrastructure service - `infrastructure.server_type_map` — maps local type keys to Infrastructure IDs diff --git a/deploy/provisioning.example.yaml b/deploy/provisioning.example.yaml index 8e7f76d..767b6be 100644 --- a/deploy/provisioning.example.yaml +++ b/deploy/provisioning.example.yaml @@ -22,8 +22,6 @@ proxmox: join_fingerprint: "AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99:AA:BB:CC:DD:EE:FF:00:11:22:33:44:55:66:77:88:99" credentials: - ssh_private_key_path: "/etc/provisioning/keys/id_ed25519" - ssh_public_key: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAEXAMPLE provisioning@homelab" root_password_hash: "$6$rounds=5000$randomsalt$hashedpasswordhere" infrastructure: diff --git a/internal/api/boot.go b/internal/api/boot.go index 8364cbb..a721509 100644 --- a/internal/api/boot.go +++ b/internal/api/boot.go @@ -82,7 +82,13 @@ func (a *BootAPI) AnswerFile(w http.ResponseWriter, r *http.Request) { a.Runner.Transition(r.Context(), host.ID, statemachine.TriggerAnswerServed) } - answer := pxe.GenerateAnswerFile(host, st, a.Config) + _, pubKey, _ := a.Hosts.GetEphemeralKey(r.Context(), host.ID) + if pubKey == "" { + http.Error(w, "no ephemeral key for host", http.StatusInternalServerError) + return + } + + answer := pxe.GenerateAnswerFile(host, st, a.Config, pubKey) w.Header().Set("Content-Type", "application/toml") w.Write([]byte(answer)) } diff --git a/internal/api/hosts.go b/internal/api/hosts.go index 4e0f766..73d2594 100644 --- a/internal/api/hosts.go +++ b/internal/api/hosts.go @@ -18,14 
+18,15 @@ import ( ) type HostAPI struct { - Hosts *store.Hosts - Ops *store.Operations - Locks *store.Locks - Images *store.Images - Runner *orchestrator.Runner - PXE *pxe.Supervisor - Config *config.Config - ServerTypes *config.ServerTypeRegistry + Hosts *store.Hosts + Ops *store.Operations + Locks *store.Locks + Images *store.Images + Runner *orchestrator.Runner + Orchestrator *orchestrator.HostOrchestrator + PXE *pxe.Supervisor + Config *config.Config + ServerTypes *config.ServerTypeRegistry } func (a *HostAPI) List(w http.ResponseWriter, r *http.Request) { @@ -124,6 +125,12 @@ func (a *HostAPI) Rebuild(w http.ResponseWriter, r *http.Request) { return } + if err := a.Orchestrator.PrepareRebuild(r.Context(), host.ID); err != nil { + _ = a.Locks.Release(r.Context(), host.ID) + writeJSONErr(w, http.StatusInternalServerError, "failed to generate SSH key: "+err.Error()) + return + } + if _, err := a.Runner.Transition(r.Context(), host.ID, statemachine.TriggerRebuildRequested); err != nil { _ = a.Locks.Release(r.Context(), host.ID) writeJSONErr(w, http.StatusConflict, err.Error()) diff --git a/internal/api/smoke_test.go b/internal/api/smoke_test.go index 5bf6e57..77a109b 100644 --- a/internal/api/smoke_test.go +++ b/internal/api/smoke_test.go @@ -58,17 +58,6 @@ func newTestServer(t *testing.T) *httptest.Server { pxeSupervisor := pxe.NewSupervisor(pxe.SupervisorConfig{Enabled: false}) - hostAPI := &api.HostAPI{ - Hosts: hosts, - Ops: ops, - Locks: locks, - Images: images, - Runner: runner, - PXE: pxeSupervisor, - Config: cfg, - ServerTypes: serverTypes, - } - hostOrch := &orchestrator.HostOrchestrator{ Runner: runner, Hosts: hosts, @@ -79,6 +68,18 @@ func newTestServer(t *testing.T) *httptest.Server { ServerTypes: serverTypes, } + hostAPI := &api.HostAPI{ + Hosts: hosts, + Ops: ops, + Locks: locks, + Images: images, + Runner: runner, + Orchestrator: hostOrch, + PXE: pxeSupervisor, + Config: cfg, + ServerTypes: serverTypes, + } + bootAPI := &api.BootAPI{ Hosts: hosts, 
Images: images, @@ -89,15 +90,16 @@ func newTestServer(t *testing.T) *httptest.Server { } ui := &api.UI{ - Hosts: hosts, - Ops: ops, - Locks: locks, - Images: images, - Runner: runner, - Hub: hub, - PXE: pxeSupervisor, - Config: cfg, - ServerTypes: serverTypes, + Hosts: hosts, + Ops: ops, + Locks: locks, + Images: images, + Runner: runner, + Orchestrator: hostOrch, + Hub: hub, + PXE: pxeSupervisor, + Config: cfg, + ServerTypes: serverTypes, } router := httpserver.NewRouter(httpserver.Deps{ diff --git a/internal/api/ui.go b/internal/api/ui.go index 0b18227..19e5f72 100644 --- a/internal/api/ui.go +++ b/internal/api/ui.go @@ -19,15 +19,16 @@ import ( ) type UI struct { - Hosts *store.Hosts - Ops *store.Operations - Locks *store.Locks - Images *store.Images - Runner *orchestrator.Runner - Hub *events.Hub - PXE *pxe.Supervisor - Config *config.Config - ServerTypes *config.ServerTypeRegistry + Hosts *store.Hosts + Ops *store.Operations + Locks *store.Locks + Images *store.Images + Runner *orchestrator.Runner + Orchestrator *orchestrator.HostOrchestrator + Hub *events.Hub + PXE *pxe.Supervisor + Config *config.Config + ServerTypes *config.ServerTypeRegistry } func (u *UI) Dashboard(w http.ResponseWriter, r *http.Request) { @@ -128,6 +129,13 @@ func (u *UI) TriggerRebuild(w http.ResponseWriter, r *http.Request) { Kind: model.OpRebuildProxmox, }) _ = u.Locks.Acquire(r.Context(), host.ID, opID) + + if err := u.Orchestrator.PrepareRebuild(r.Context(), host.ID); err != nil { + _ = u.Locks.Release(r.Context(), host.ID) + http.Error(w, "Failed to prepare rebuild: "+err.Error(), http.StatusInternalServerError) + return + } + u.Runner.Transition(r.Context(), host.ID, statemachine.TriggerRebuildRequested) hosts, _ := u.Hosts.List(r.Context()) diff --git a/internal/config/config.go b/internal/config/config.go index d08b14d..f6b8b27 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,9 +48,7 @@ type Proxmox struct { } type Credentials struct { - 
SSHPrivateKeyPath string `yaml:"ssh_private_key_path"` - SSHPublicKey string `yaml:"ssh_public_key"` - RootPasswordHash string `yaml:"root_password_hash"` + RootPasswordHash string `yaml:"root_password_hash"` } type Infrastructure struct { diff --git a/internal/db/migrations/0002_ephemeral_ssh_key.sql b/internal/db/migrations/0002_ephemeral_ssh_key.sql new file mode 100644 index 0000000..21388dc --- /dev/null +++ b/internal/db/migrations/0002_ephemeral_ssh_key.sql @@ -0,0 +1,2 @@ +ALTER TABLE hosts ADD COLUMN ssh_private_key TEXT; +ALTER TABLE hosts ADD COLUMN ssh_public_key TEXT; diff --git a/internal/orchestrator/cluster.go b/internal/orchestrator/cluster.go index 539a639..8f39e88 100644 --- a/internal/orchestrator/cluster.go +++ b/internal/orchestrator/cluster.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log" - "os" "time" "golang.org/x/crypto/ssh" @@ -14,41 +13,39 @@ type ClusterJoiner struct { ExistingNode string ClusterName string JoinFingerprint string - SSHKeyPath string } -func (c *ClusterJoiner) Join(ctx context.Context, hostIP string) error { - client, err := c.connect(hostIP) +func (c *ClusterJoiner) Join(ctx context.Context, hostIP string, privateKey string, publicKey string) error { + client, err := c.connect(hostIP, privateKey) if err != nil { return fmt.Errorf("ssh connect to %s: %w", hostIP, err) } defer client.Close() + // Join the cluster cmd := fmt.Sprintf("pvecm add %s --force", c.ExistingNode) log.Printf("cluster: running on %s: %s", hostIP, cmd) - - session, err := client.NewSession() - if err != nil { - return fmt.Errorf("ssh session: %w", err) - } - defer session.Close() - - output, err := session.CombinedOutput(cmd) - if err != nil { - return fmt.Errorf("pvecm add failed: %w\noutput: %s", err, string(output)) + if err := c.runCmd(client, cmd); err != nil { + return fmt.Errorf("pvecm add failed: %w", err) } log.Printf("cluster: %s joined successfully", hostIP) + + // Remove the ephemeral key from authorized_keys + escaped := 
// escapeForSed turns s into a pattern that matches s literally when it is
// used as a sed address whose delimiter is '|', i.e. \|PATTERN|d.
//
//   - Line breaks (LF and CR) are dropped wherever they occur, so a key that
//     was stored with its trailing newline still yields a one-line pattern.
//   - The delimiter '|' and the characters that are special in a POSIX basic
//     regular expression (\ . [ * ^ $) are prefixed with a backslash so they
//     lose their special meaning.
//   - Everything else is copied through unchanged. In particular '+', '?',
//     '{' and ']' are ordinary characters in a basic regular expression and
//     must NOT be escaped, because GNU sed treats their escaped forms as
//     operators.
//
// The result is not shell-quoted; a caller that embeds it in a shell command
// line is responsible for quoting it.
func escapeForSed(s string) string {
	// Every character that needs attention is ASCII, so the string can be
	// processed byte by byte; multi-byte UTF-8 sequences pass through intact.
	out := make([]byte, 0, len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch c {
		case '\n', '\r':
			// Dropped: the pattern must stay on a single line.
		case '|', '\\', '.', '[', '*', '^', '$':
			out = append(out, '\\', c)
		default:
			out = append(out, c)
		}
	}
	return string(out)
}
+// The public key will be injected into the Proxmox answer file. +func (o *HostOrchestrator) PrepareRebuild(ctx context.Context, hostID int64) error { + kp, err := GenerateEphemeralKey() + if err != nil { + return err + } + return o.Hosts.SetEphemeralKey(ctx, hostID, kp.PrivateKey, kp.PublicKey) +} + func (o *HostOrchestrator) HandlePhoneHome(ctx context.Context, hostID int64, ip string, hardwareID string) { if err := o.Hosts.UpdateIP(ctx, hostID, ip, hardwareID); err != nil { log.Printf("host %d: failed to update IP: %v", hostID, err) @@ -47,17 +57,27 @@ func (o *HostOrchestrator) postPhoneHome(hostID int64, ip string, hardwareID str return } + privateKey, publicKey, err := o.Hosts.GetEphemeralKey(ctx, hostID) + if err != nil || privateKey == "" { + log.Printf("host %d: no ephemeral key available: %v", hostID, err) + o.Runner.FailHost(ctx, hostID, "no ephemeral SSH key") + return + } + if _, err := o.Runner.Transition(ctx, hostID, statemachine.TriggerClusterJoinStart); err != nil { log.Printf("host %d: cluster join start transition failed: %v", hostID, err) return } - if err := o.Cluster.Join(ctx, ip); err != nil { + if err := o.Cluster.Join(ctx, ip, privateKey, publicKey); err != nil { log.Printf("host %d: cluster join failed: %v", hostID, err) o.Runner.FailHost(ctx, hostID, "cluster join: "+err.Error()) return } + // Key has been removed from the remote host; clear it from the DB + _ = o.Hosts.ClearEphemeralKey(ctx, hostID) + if err := o.registerInfra(ctx, host, ip, hardwareID); err != nil { log.Printf("host %d: infra registration failed: %v", hostID, err) o.Runner.FailHost(ctx, hostID, "infra registration: "+err.Error()) diff --git a/internal/orchestrator/sshkey.go b/internal/orchestrator/sshkey.go new file mode 100644 index 0000000..462fa86 --- /dev/null +++ b/internal/orchestrator/sshkey.go @@ -0,0 +1,39 @@ +package orchestrator + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "fmt" + + "golang.org/x/crypto/ssh" +) + +type KeyPair struct { + 
PrivateKey string + PublicKey string +} + +func GenerateEphemeralKey() (*KeyPair, error) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("generate ed25519 key: %w", err) + } + + sshPub, err := ssh.NewPublicKey(pub) + if err != nil { + return nil, fmt.Errorf("ssh public key: %w", err) + } + pubStr := string(ssh.MarshalAuthorizedKey(sshPub)) + + privBytes, err := ssh.MarshalPrivateKey(priv, "") + if err != nil { + return nil, fmt.Errorf("marshal private key: %w", err) + } + privStr := string(pem.EncodeToMemory(privBytes)) + + return &KeyPair{ + PrivateKey: privStr, + PublicKey: pubStr, + }, nil +} diff --git a/internal/pxe/answer.go b/internal/pxe/answer.go index b633c83..7282022 100644 --- a/internal/pxe/answer.go +++ b/internal/pxe/answer.go @@ -8,7 +8,7 @@ import ( "provisioning/internal/model" ) -func GenerateAnswerFile(host *model.Host, serverType model.ServerType, cfg *config.Config) string { +func GenerateAnswerFile(host *model.Host, serverType model.ServerType, cfg *config.Config, sshPublicKey string) string { var b strings.Builder b.WriteString("[global]\n") @@ -18,7 +18,7 @@ func GenerateAnswerFile(host *model.Host, serverType model.ServerType, cfg *conf b.WriteString(`mailto = "admin@thewrightserver.net"` + "\n") b.WriteString(`timezone = "America/Indiana/Indianapolis"` + "\n") b.WriteString(fmt.Sprintf("root-password-hashed = \"%s\"\n", cfg.Credentials.RootPasswordHash)) - b.WriteString(fmt.Sprintf("root-ssh-keys = [\"%s\"]\n", cfg.Credentials.SSHPublicKey)) + b.WriteString(fmt.Sprintf("root-ssh-keys = [\"%s\"]\n", strings.TrimSpace(sshPublicKey))) b.WriteString("\n") b.WriteString("[network]\n") diff --git a/internal/store/hosts.go b/internal/store/hosts.go index 1aee526..f89a55c 100644 --- a/internal/store/hosts.go +++ b/internal/store/hosts.go @@ -109,6 +109,25 @@ func (s *Hosts) UpdateInfraID(ctx context.Context, id int64, infraHostID int64) return err } +func (s *Hosts) SetEphemeralKey(ctx 
context.Context, id int64, privateKey, publicKey string) error { + _, err := s.DB.ExecContext(ctx, `UPDATE hosts SET ssh_private_key = ?, ssh_public_key = ? WHERE id = ?`, privateKey, publicKey, id) + return err +} + +func (s *Hosts) GetEphemeralKey(ctx context.Context, id int64) (privateKey, publicKey string, err error) { + var priv, pub sql.NullString + err = s.DB.QueryRowContext(ctx, `SELECT ssh_private_key, ssh_public_key FROM hosts WHERE id = ?`, id).Scan(&priv, &pub) + if err != nil { + return "", "", err + } + return priv.String, pub.String, nil +} + +func (s *Hosts) ClearEphemeralKey(ctx context.Context, id int64) error { + _, err := s.DB.ExecContext(ctx, `UPDATE hosts SET ssh_private_key = NULL, ssh_public_key = NULL WHERE id = ?`, id) + return err +} + func (s *Hosts) Delete(ctx context.Context, id int64) error { res, err := s.DB.ExecContext(ctx, `DELETE FROM hosts WHERE id = ?`, id) if err != nil {