server.go
236 lines1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
package platform
import (
"context"
"fmt"
"log"
"os"
"os/exec"
"strings"
"time"
)
// DefaultSSHTimeout bounds a single remote command run via SSH when the
// Server has no positive SSHTimeout of its own. Long-running commands should
// call SSHWithTimeout with an explicit limit instead.
const DefaultSSHTimeout time.Duration = time.Minute
// Server describes one cloud virtual machine and is the receiver for the
// SSH/SCP helpers in this file.
type Server struct {
	ID        string
	Name      string
	IP        string // address the SSH/SCP helpers connect to (as root)
	PrivateIP string // VPC-internal address; empty when there is no VPC
	Size      string // provider size slug, e.g. "s-1vcpu-2gb"
	Region    string // provider region slug, e.g. "nyc1"
	Status    string // lifecycle state: creating, active, off

	// SSHTimeout caps each SSH command; a value <= 0 selects DefaultSSHTimeout.
	SSHTimeout time.Duration

	// sshHook is a test seam: when non-nil, SSH and SSHWithTimeout hand their
	// arguments to it instead of spawning a real ssh process.
	sshHook func(args ...string) (string, error)

	// copyHook is a test seam: when non-nil, Copy delegates to it instead of
	// spawning a real scp process.
	copyHook func(local, remote string) error
}
// sshOptions builds the "-o key=value" flags shared by every ssh/scp
// invocation against this server:
//   - StrictHostKeyChecking=accept-new trusts a host key the first time it is
//     seen but still rejects a changed key later.
//   - ControlMaster/ControlPath/ControlPersist share one multiplexed connection
//     per server IP (socket under /tmp), kept alive 60s after last use, so
//     repeated calls don't each open a new TCP connection.
//   - ConnectTimeout=15 stops connection attempts to unreachable hosts from
//     hanging indefinitely.
//   - ServerAliveInterval/ServerAliveCountMax drop a connection after about
//     three unanswered 10s keepalives.
func (s *Server) sshOptions() []string {
	settings := [...]string{
		"StrictHostKeyChecking=accept-new",
		"ControlMaster=auto",
		fmt.Sprintf("ControlPath=/tmp/ssh-congo-%s", s.IP),
		"ControlPersist=60",
		"ConnectTimeout=15",
		"ServerAliveInterval=10",
		"ServerAliveCountMax=3",
	}
	opts := make([]string, 0, 2*len(settings))
	for _, setting := range settings {
		opts = append(opts, "-o", setting)
	}
	return opts
}
// sshTimeout picks the per-command SSH limit: the server's own SSHTimeout
// when it is positive, otherwise DefaultSSHTimeout.
func (s *Server) sshTimeout() time.Duration {
	if s.SSHTimeout <= 0 {
		return DefaultSSHTimeout
	}
	return s.SSHTimeout
}
// SSH runs a command on the server and returns its output, bounded by the
// server's configured SSH timeout (see sshTimeout; DefaultSSHTimeout unless
// overridden). It is shorthand for SSHWithTimeout(s.sshTimeout(), args...).
// SECURITY: each argument is shell-escaped before being sent, so arguments
// cannot inject extra shell syntax.
func (s *Server) SSH(args ...string) (string, error) {
	limit := s.sshTimeout()
	return s.SSHWithTimeout(limit, args...)
}
// SSHWithTimeout runs a command on the server as root over ssh and returns
// its combined stdout/stderr, killing ssh if it has not finished within
// timeout. Use it for long-running operations (e.g. docker build) that need
// more than the default limit.
//
// Every argument is shell-escaped and the escaped words are joined into one
// remote command line, so arguments cannot inject shell syntax.
//
// On success the output is returned with surrounding whitespace trimmed. On
// failure the raw output is returned together with an error. The error
// contains the output for ordinary failures, and wraps
// context.DeadlineExceeded when the time limit was hit. If sshHook is set,
// the call is delegated to it and timeout is ignored.
func (s *Server) SSHWithTimeout(timeout time.Duration, args ...string) (string, error) {
	if s.sshHook != nil {
		return s.sshHook(args...)
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	escaped := make([]string, len(args))
	for i, arg := range args {
		escaped[i] = shellEscape(arg)
	}
	cmdArgs := append(s.sshOptions(), "root@"+s.IP, strings.Join(escaped, " "))

	output, err := exec.CommandContext(ctx, "ssh", cmdArgs...).CombinedOutput()
	if err != nil {
		// Only blame the deadline when the command actually failed: checking
		// ctx.Err() unconditionally would misreport a command that completed
		// successfully just before the deadline expired.
		if ctx.Err() == context.DeadlineExceeded {
			return string(output), fmt.Errorf("ssh: command timed out after %s: %w", timeout, ctx.Err())
		}
		return string(output), fmt.Errorf("ssh: %w: %s", err, output)
	}
	return strings.TrimSpace(string(output)), nil
}
// fileExists reports whether os.Stat succeeds for path. Any Stat error, not
// only "does not exist" (permission denied, for example), yields false.
func fileExists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}
// shellEscape makes s safe to splice into a POSIX shell command line as a
// single literal word.
//
// A non-empty string made only of ASCII letters, digits and the characters
// - _ . / : = , is returned unchanged. Anything else, including the empty
// string and any non-ASCII text, is wrapped in single quotes, and each
// embedded single quote is written as '"'"' (close quote, quoted quote,
// reopen quote).
func shellEscape(s string) string {
	needsQuoting := func(r rune) bool {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return false
		default:
			return !strings.ContainsRune("-_./:=,", r)
		}
	}
	if s != "" && strings.IndexFunc(s, needsQuoting) == -1 {
		return s
	}
	return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
}
// Copy uploads the local file at local to path remote on the server (as
// root) using scp, with the same connection options as SSH.
//
// The transfer is killed if it takes longer than five minutes. In that case
// the returned error wraps context.DeadlineExceeded. Other failures return
// an error that includes scp's combined output. If copyHook is set, the call
// is delegated to it instead.
func (s *Server) Copy(local, remote string) error {
	if s.copyHook != nil {
		return s.copyHook(local, remote)
	}
	const uploadTimeout = 5 * time.Minute
	ctx, cancel := context.WithTimeout(context.Background(), uploadTimeout)
	defer cancel()

	cmdArgs := append(s.sshOptions(), local, "root@"+s.IP+":"+remote)
	output, err := exec.CommandContext(ctx, "scp", cmdArgs...).CombinedOutput()
	if err == nil {
		// Success wins even if the deadline expired right after scp exited.
		return nil
	}
	if ctx.Err() == context.DeadlineExceeded {
		return fmt.Errorf("scp: file upload timed out after %s: %w", uploadTimeout, ctx.Err())
	}
	return fmt.Errorf("scp: %w: %s", err, output)
}
// Interactive connects to the server as root with ssh and passes no remote
// command, attaching the current process's stdin, stdout and stderr to the
// session. No timeout is applied. It blocks until ssh exits and returns the
// error from running it.
func (s *Server) Interactive() error {
	target := "root@" + s.IP
	session := exec.Command("ssh", append(s.sshOptions(), target)...)
	session.Stdin, session.Stdout, session.Stderr = os.Stdin, os.Stdout, os.Stderr
	return session.Run()
}
// WaitForSSH polls the server until a real SSH command ("echo ready")
// succeeds, returning nil at that point or ErrTimeout once timeout has
// elapsed. It runs a full command rather than a TCP dial because sshd may
// accept connections before it is ready to authenticate.
//
// Attempts are spaced up to 5 seconds apart. Both the per-attempt SSH limit
// and the pause between attempts are capped at the time remaining, so the
// call does not overrun timeout by a whole probe or sleep.
func (s *Server) WaitForSSH(timeout time.Duration) error {
	const pollInterval = 5 * time.Second
	deadline := time.Now().Add(timeout)
	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return ErrTimeout
		}

		probe := s.sshTimeout()
		if probe > remaining {
			probe = remaining
		}
		if _, err := s.SSHWithTimeout(probe, "echo", "ready"); err == nil {
			return nil
		}

		pause := time.Until(deadline)
		if pause > pollInterval {
			pause = pollInterval
		}
		if pause > 0 {
			time.Sleep(pause)
		}
	}
}
// Connect runs a command on the server as root over ssh with the current
// process's stdin, stdout and stderr attached, and returns the error from
// running ssh. No timeout is applied.
//
// Each argument is quoted with shellEscape and the quoted words are joined
// into one remote command line. With no arguments, no remote command is
// passed to ssh at all.
func (s *Server) Connect(args ...string) error {
	argv := append(s.sshOptions(), "root@"+s.IP)
	if len(args) != 0 {
		quoted := make([]string, 0, len(args))
		for _, word := range args {
			quoted = append(quoted, shellEscape(word))
		}
		argv = append(argv, strings.Join(quoted, " "))
	}
	proc := exec.Command("ssh", argv...)
	proc.Stdin, proc.Stdout, proc.Stderr = os.Stdin, os.Stdout, os.Stderr
	return proc.Run()
}
// Write stores data at remotePath on the server. The bytes are first written
// to a local temporary file (always removed afterwards), which is then
// uploaded with Copy. When executable is true, "chmod +x remotePath" is run
// on the server afterwards.
//
// The temporary file is closed on every path, and its Close error is
// checked. A failed flush would otherwise upload a truncated file.
func (s *Server) Write(remotePath string, data []byte, executable bool) error {
	tmp, err := os.CreateTemp("", "congo-*")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()
	defer os.Remove(tmpPath)

	if _, err := tmp.Write(data); err != nil {
		// The write error is what matters; Close only releases the descriptor.
		tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}

	if err := s.Copy(tmpPath, remotePath); err != nil {
		return err
	}
	if executable {
		_, err := s.SSH("chmod", "+x", remotePath)
		return err
	}
	return nil
}
// RunScript uploads the local script at localPath to /tmp/setup.sh on the
// server and runs it with bash, using the default per-command SSH timeout.
// Each non-empty line of the script's output is logged.
//
// A missing local file is treated as "nothing to do" and returns nil. An
// upload failure, or a failed run, returns an error naming localPath. The
// error for a failed run is followed by the script's output.
func (s *Server) RunScript(localPath string) error {
	_, statErr := os.Stat(localPath)
	if os.IsNotExist(statErr) {
		return nil
	}
	log.Printf(" Running %s...", localPath)

	const remoteScript = "/tmp/setup.sh"
	if err := s.Copy(localPath, remoteScript); err != nil {
		return fmt.Errorf("upload %s: %w", localPath, err)
	}
	out, err := s.SSH("bash", remoteScript)
	if err != nil {
		return fmt.Errorf("%s: %w\n%s", localPath, err, out)
	}

	lines := strings.Split(strings.TrimSpace(out), "\n")
	for _, ln := range lines {
		if ln == "" {
			continue
		}
		log.Printf(" %s", ln)
	}
	return nil
}