server.go

236 lines
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
package platform

import (
	"context"
	"fmt"
	"log"
	"os"
	"os/exec"
	"strings"
	"time"
)

// DefaultSSHTimeout is the per-command SSH timeout (60 seconds) used when a
// Server's SSHTimeout field is zero or negative.
// Set Server.SSHTimeout to override it per server, or call SSHWithTimeout to
// choose a timeout for a single call.
const DefaultSSHTimeout = 60 * time.Second

// Server represents a cloud virtual machine and carries the settings its
// SSH/SCP helper methods need. Those helpers always connect as root@IP.
type Server struct {
	ID         string
	Name       string
	IP         string        // Address used as the SSH/SCP target and in the ControlPath socket name
	PrivateIP  string        // Private IP for VPC networking (empty if no VPC)
	Size       string        // e.g., "s-1vcpu-2gb"
	Region     string        // e.g., "nyc1"
	Status     string        // creating, active, off
	SSHTimeout time.Duration // Per-command SSH timeout (0 = DefaultSSHTimeout)

	// sshHook intercepts SSH calls for testing. If set, SSH and SSHWithTimeout
	// call this function instead of executing a real SSH command. The hook
	// receives the raw, unescaped arguments and the timeout is not applied.
	sshHook func(args ...string) (string, error)
	// copyHook intercepts Copy calls for testing. If set, Copy calls this
	// function instead of executing a real SCP command.
	copyHook func(local, remote string) error
}

// sshOptions builds the "-o key=value" flags shared by every ssh and scp
// invocation for this server. A fresh slice is returned on each call, so
// callers may append to it freely.
//
// The settings enable connection multiplexing (ControlMaster/ControlPath/
// ControlPersist) so connections are reused and rate limiting is avoided,
// bound the connection attempt (ConnectTimeout) so unreachable hosts do not
// hang indefinitely, and enable keep-alives (ServerAliveInterval/
// ServerAliveCountMax) so dead connections are detected.
func (s *Server) sshOptions() []string {
	settings := []string{
		"StrictHostKeyChecking=accept-new",
		"ControlMaster=auto",
		"ControlPath=/tmp/ssh-congo-" + s.IP, // one control socket per server IP
		"ControlPersist=60",
		"ConnectTimeout=15",
		"ServerAliveInterval=10",
		"ServerAliveCountMax=3",
	}

	// Each setting is passed as its own "-o" flag.
	opts := make([]string, 0, 2*len(settings))
	for _, setting := range settings {
		opts = append(opts, "-o", setting)
	}
	return opts
}

// sshTimeout resolves the effective per-command SSH timeout: the server's
// SSHTimeout when it is positive, DefaultSSHTimeout otherwise.
func (s *Server) sshTimeout() time.Duration {
	if s.SSHTimeout <= 0 {
		return DefaultSSHTimeout
	}
	return s.SSHTimeout
}

// SSH runs a command on the server and returns its output.
// The call is bounded by the server's configured SSHTimeout (default 60s) so
// it cannot hang indefinitely; use SSHWithTimeout for a per-call limit.
// SECURITY: Arguments are shell-escaped to prevent command injection.
func (s *Server) SSH(args ...string) (string, error) {
	timeout := s.sshTimeout()
	return s.SSHWithTimeout(timeout, args...)
}

// SSHWithTimeout executes a command on the server with a specific timeout.
// Use this for long-running operations like docker build that need more time.
//
// A non-positive timeout falls back to the server's configured timeout
// (SSHTimeout, or DefaultSSHTimeout when unset).
//
// On success it returns the combined stdout/stderr with surrounding
// whitespace trimmed. On failure it returns the raw combined output together
// with an error: a "timed out" error when the deadline was exceeded,
// otherwise the ssh error wrapped with the output.
//
// If sshHook is set, the hook is called with the unescaped args instead and
// the timeout is not applied.
func (s *Server) SSHWithTimeout(timeout time.Duration, args ...string) (string, error) {
	if s.sshHook != nil {
		return s.sshHook(args...)
	}

	// A non-positive timeout would yield an already-expired context and make
	// every call fail instantly; treat it as "use the configured default".
	if timeout <= 0 {
		timeout = s.sshTimeout()
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// The remote command is sent as one string, so shell-escape each
	// argument to prevent command injection.
	escapedArgs := make([]string, len(args))
	for i, arg := range args {
		escapedArgs[i] = shellEscape(arg)
	}
	cmdArgs := append(s.sshOptions(), "root@"+s.IP, strings.Join(escapedArgs, " "))

	output, err := exec.CommandContext(ctx, "ssh", cmdArgs...).CombinedOutput()
	// Check success first: a command that finished cleanly must not be
	// reported as a timeout just because the deadline passed right after.
	if err == nil {
		return strings.TrimSpace(string(output)), nil
	}
	if ctx.Err() == context.DeadlineExceeded {
		return string(output), fmt.Errorf("ssh: command timed out after %s", timeout)
	}
	return string(output), fmt.Errorf("ssh: %w: %s", err, output)
}

// fileExists reports whether os.Stat succeeds for path.
// Any Stat error, not only "does not exist", is reported as false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}

// shellEscape returns s in a form that is safe to embed in a shell command.
//
// A non-empty string made up solely of ASCII letters, digits and the
// characters - _ . / : = , is returned unchanged. Anything else, including
// the empty string, is wrapped in single quotes, with each embedded single
// quote rewritten as '"'"' (close quote, double-quoted quote, reopen quote).
func shellEscape(s string) string {
	needsQuoting := func(r rune) bool {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return false
		}
		switch r {
		case '-', '_', '.', '/', ':', '=', ',':
			return false
		}
		return true
	}

	if s != "" && strings.IndexFunc(s, needsQuoting) < 0 {
		return s
	}
	return "'" + strings.ReplaceAll(s, "'", `'"'"'`) + "'"
}

// Copy uploads the local file to the remote path on the server via SCP,
// connecting as root and reusing the shared SSH options.
// Uses a 5-minute timeout to prevent indefinite hangs on large files.
//
// It returns nil on success, a "timed out" error when the 5-minute deadline
// was exceeded, and otherwise the scp error wrapped with scp's combined
// output. If copyHook is set, the hook is called instead of running scp.
func (s *Server) Copy(local, remote string) error {
	if s.copyHook != nil {
		return s.copyHook(local, remote)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()

	cmdArgs := append(s.sshOptions(), local, "root@"+s.IP+":"+remote)
	output, err := exec.CommandContext(ctx, "scp", cmdArgs...).CombinedOutput()
	// Check success first: an upload that completed cleanly must not be
	// reported as a timeout just because the deadline passed right after.
	if err == nil {
		return nil
	}
	if ctx.Err() == context.DeadlineExceeded {
		return fmt.Errorf("scp: file upload timed out after 5m")
	}
	return fmt.Errorf("scp: %w: %s", err, output)
}

// Interactive opens an interactive SSH shell as root on the server, wiring
// the session to this process's stdin, stdout and stderr. It blocks until
// the ssh process exits and returns its Run error. No timeout is applied.
func (s *Server) Interactive() error {
	session := exec.Command("ssh", append(s.sshOptions(), "root@"+s.IP)...)
	session.Stdin, session.Stdout, session.Stderr = os.Stdin, os.Stdout, os.Stderr
	return session.Run()
}

// WaitForSSH waits until the server accepts SSH connections.
// Uses a real SSH command to verify full authentication works,
// not just a TCP dial (sshd may accept connections before it's
// ready to authenticate).
//
// It probes with "echo ready" and retries every 5 seconds until a probe
// succeeds (returns nil) or timeout has elapsed (returns ErrTimeout). Each
// probe and each pause is capped at the time remaining, so the call does not
// run meaningfully past the requested timeout. A non-positive timeout
// returns ErrTimeout without probing.
func (s *Server) WaitForSSH(timeout time.Duration) error {
	const retryDelay = 5 * time.Second

	deadline := time.Now().Add(timeout)
	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return ErrTimeout
		}

		// Never let a single probe outlive the overall deadline.
		probeTimeout := s.sshTimeout()
		if remaining < probeTimeout {
			probeTimeout = remaining
		}
		if _, err := s.SSHWithTimeout(probeTimeout, "echo", "ready"); err == nil {
			return nil
		}

		// Pause before retrying, but do not sleep beyond the deadline.
		// A non-positive duration makes time.Sleep return immediately.
		pause := retryDelay
		if left := time.Until(deadline); left < pause {
			pause = left
		}
		time.Sleep(pause)
	}
}

// Connect runs ssh as root on the server with this process's stdin, stdout
// and stderr attached. When args are given they are shell-escaped, joined
// with spaces and sent as the remote command; with no args no remote command
// is passed. It blocks until ssh exits and returns its Run error. No timeout
// is applied.
func (s *Server) Connect(args ...string) error {
	sshArgs := append(s.sshOptions(), "root@"+s.IP)
	if len(args) != 0 {
		quoted := make([]string, 0, len(args))
		for _, arg := range args {
			quoted = append(quoted, shellEscape(arg))
		}
		sshArgs = append(sshArgs, strings.Join(quoted, " "))
	}

	proc := exec.Command("ssh", sshArgs...)
	proc.Stdin, proc.Stdout, proc.Stderr = os.Stdin, os.Stdout, os.Stderr
	return proc.Run()
}

// Write writes data to the file at remotePath on the server.
//
// The data is staged in a local temporary file, uploaded with Copy, and the
// temporary file is removed before returning. When executable is true,
// "chmod +x" is then run on remotePath over SSH.
//
// It returns the first error from creating, writing or closing the
// temporary file, from the upload, or from the chmod.
func (s *Server) Write(remotePath string, data []byte, executable bool) error {
	tmp, err := os.CreateTemp("", "congo-*")
	if err != nil {
		return err
	}
	defer os.Remove(tmp.Name())

	if _, err := tmp.Write(data); err != nil {
		// Release the handle; the write error is the one worth reporting.
		tmp.Close()
		return err
	}
	// A failed Close can mean the data never reached disk, so do not upload
	// a possibly truncated file.
	if err := tmp.Close(); err != nil {
		return err
	}

	if err := s.Copy(tmp.Name(), remotePath); err != nil {
		return err
	}

	if executable {
		_, err := s.SSH("chmod", "+x", remotePath)
		return err
	}
	return nil
}

// RunScript uploads a local setup script to /tmp/setup.sh on the server and
// runs it there with bash, logging each non-empty line of its output.
// If the script file doesn't exist locally, it's a no-op and nil is returned.
// Upload and execution failures are returned wrapped with the local path;
// an execution failure also carries the script's output.
func (s *Server) RunScript(localPath string) error {
	const remoteScript = "/tmp/setup.sh"

	// Only a definite "does not exist" skips the script; any other Stat
	// outcome proceeds to the upload.
	_, statErr := os.Stat(localPath)
	if os.IsNotExist(statErr) {
		return nil
	}

	log.Printf("   Running %s...", localPath)
	uploadErr := s.Copy(localPath, remoteScript)
	if uploadErr != nil {
		return fmt.Errorf("upload %s: %w", localPath, uploadErr)
	}

	output, runErr := s.SSH("bash", remoteScript)
	if runErr != nil {
		return fmt.Errorf("%s: %w\n%s", localPath, runErr, output)
	}

	// Echo the script output line by line, skipping blank lines.
	lines := strings.Split(strings.TrimSpace(output), "\n")
	for _, line := range lines {
		if line == "" {
			continue
		}
		log.Printf("   %s", line)
	}
	return nil
}