diff --git a/lib/exec/README.md b/lib/exec/README.md
index 48d776a2..33735a1d 100644
--- a/lib/exec/README.md
+++ b/lib/exec/README.md
@@ -45,10 +45,12 @@ Container (chroot /overlay/newroot)
 - **ExecIntoInstance()**: Main client function
 - Connects to Cloud Hypervisor's vsock Unix socket
 - Performs vsock handshake: `CONNECT 2222\n` → `OK `
-- Creates gRPC client over the vsock connection
+- Creates gRPC client over the vsock connection (pooled per VM for efficiency)
 - Streams stdin/stdout/stderr bidirectionally
 - Returns exit status when command completes
 
+**Concurrency**: Multiple exec calls to the same VM share the underlying gRPC connection but use separate streams, enabling concurrent command execution.
+
 ### 3. Protocol (`exec.proto`)
 
 gRPC streaming RPC with protobuf messages:
@@ -84,6 +86,7 @@ gRPC streaming RPC with protobuf messages:
 
 - **Bidirectional streaming**: Real-time stdin/stdout/stderr
 - **TTY support**: Interactive shells with terminal control
+- **Concurrent exec**: Multiple simultaneous commands per VM (separate streams)
 - **Exit codes**: Proper process exit status reporting
 - **No SSH required**: Direct vsock communication (faster, simpler)
 - **Container isolation**: Commands run in container context, not VM context
diff --git a/lib/exec/client.go b/lib/exec/client.go
index 35dc2786..0484bc96 100644
--- a/lib/exec/client.go
+++ b/lib/exec/client.go
@@ -8,6 +8,7 @@ import (
 	"log/slog"
 	"net"
 	"strings"
+	"sync"
 	"time"
 
 	"google.golang.org/grpc"
@@ -23,6 +24,67 @@ const (
 	vsockGuestPort = 2222
 )
 
+// connPool caches one gRPC connection per vsock socket path.
+// Reusing connections avoids the handshake overhead and teardown races of rapidly creating and closing them.
+var connPool = struct {
+	sync.RWMutex
+	conns map[string]*grpc.ClientConn
+}{
+	conns: make(map[string]*grpc.ClientConn),
+}
+
+// getOrCreateConn returns an existing connection or creates a new one
+func getOrCreateConn(ctx context.Context, vsockSocketPath string) (*grpc.ClientConn, error) {
+	// Fast path: read lock for an existing connection
+	connPool.RLock()
+	if conn, ok := connPool.conns[vsockSocketPath]; ok {
+		connPool.RUnlock()
+		return conn, nil
+	}
+	connPool.RUnlock()
+
+	// Need to create a new connection - acquire write lock
+	connPool.Lock()
+	defer connPool.Unlock()
+
+	// Double-check after acquiring write lock
+	if conn, ok := connPool.conns[vsockSocketPath]; ok {
+		return conn, nil
+	}
+
+	// Create new connection (NewClient is lazy: the dialer runs on first RPC)
+	conn, err := grpc.NewClient("passthrough:///vsock",
+		grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
+			return dialVsock(ctx, vsockSocketPath)
+		}),
+		grpc.WithTransportCredentials(insecure.NewCredentials()),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("create grpc connection: %w", err)
+	}
+
+	connPool.conns[vsockSocketPath] = conn
+	slog.Debug("created new gRPC connection", "socket", vsockSocketPath)
+	return conn, nil
+}
+
+// CloseConn closes and removes a connection from the pool (call when the VM is deleted)
+func CloseConn(vsockSocketPath string) {
+	connPool.Lock()
+	defer connPool.Unlock()
+
+	if conn, ok := connPool.conns[vsockSocketPath]; ok {
+		conn.Close()
+		delete(connPool.conns, vsockSocketPath)
+		slog.Debug("closed gRPC connection", "socket", vsockSocketPath)
+	}
+}
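+
+// Pool lifecycle sketch (hypothetical call site - the real cleanup hook may
+// live elsewhere in the instance manager):
+//
+//	exec.CloseConn(inst.VsockSocket) // drop the pooled conn before deleting the VM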
+
 // ExitStatus represents command exit information
 type ExitStatus struct {
 	Code int
@@ -54,17 +116,13 @@ func (c *bufferedConn) Read(p []byte) (int, error) {
 // ExecIntoInstance executes command in instance via vsock using gRPC
 // vsockSocketPath is the Unix socket created by Cloud Hypervisor (e.g., /var/lib/hypeman/guests/{id}/vsock.sock)
 func ExecIntoInstance(ctx context.Context, vsockSocketPath string, opts ExecOptions) (*ExitStatus, error) {
-	// Connect to Cloud Hypervisor's vsock Unix socket with custom dialer
-	grpcConn, err := grpc.NewClient("passthrough:///vsock",
-		grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
-			return dialVsock(ctx, vsockSocketPath)
-		}),
-		grpc.WithTransportCredentials(insecure.NewCredentials()),
-	)
+	// Get or create a reusable gRPC connection for this vsock socket
+	// Connection pooling avoids issues with rapid connect/disconnect cycles
+	grpcConn, err := getOrCreateConn(ctx, vsockSocketPath)
 	if err != nil {
-		return nil, fmt.Errorf("create grpc client: %w", err)
+		return nil, fmt.Errorf("get grpc connection: %w", err)
 	}
-	defer grpcConn.Close()
+	// Note: don't close the connection - it's pooled and reused
 
 	// Create exec client
 	client := NewExecServiceClient(grpcConn)
@@ -72,6 +130,8 @@ func ExecIntoInstance(ctx context.Context, vsockSocketPath string, opts ExecOpti
 	if err != nil {
 		return nil, fmt.Errorf("start exec stream: %w", err)
 	}
+	// Half-close our send direction on return so the agent sees stdin EOF
+	defer stream.CloseSend()
 
 	// Send start request
 	if err := stream.Send(&ExecRequest{
@@ -108,22 +168,24 @@ func ExecIntoInstance(ctx context.Context, vsockSocketPath string, opts ExecOpti
 	}
 
 	// Receive responses
+	var totalStdout, totalStderr int
 	for {
 		resp, err := stream.Recv()
 		if err == io.EOF {
-			// Stream closed without exit code
-			return nil, fmt.Errorf("stream closed without exit code")
+			return nil, fmt.Errorf("stream closed without exit code (stdout=%d, stderr=%d)", totalStdout, totalStderr)
 		}
 		if err != nil {
-			return nil, fmt.Errorf("receive response: %w", err)
+			return nil, fmt.Errorf("receive response (stdout=%d, stderr=%d): %w", totalStdout, totalStderr, err)
 		}
 
 		switch r := resp.Response.(type) {
 		case *ExecResponse_Stdout:
+			totalStdout += len(r.Stdout)
 			if opts.Stdout != nil {
 				opts.Stdout.Write(r.Stdout)
 			}
 		case *ExecResponse_Stderr:
+			totalStderr += len(r.Stderr)
 			if opts.Stderr != nil {
 				opts.Stderr.Write(r.Stderr)
 			}
diff --git a/lib/instances/exec_test.go b/lib/instances/exec_test.go
new file mode 100644
index 00000000..ef0ba8f0
--- /dev/null
+++ b/lib/instances/exec_test.go
@@ -0,0 +1,229 @@
+package instances
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/onkernel/hypeman/lib/images"
+	"github.com/onkernel/hypeman/lib/paths"
+	"github.com/onkernel/hypeman/lib/system"
+	"github.com/stretchr/testify/require"
+)
+
+// waitForExecAgent polls until exec-agent is ready
+func waitForExecAgent(ctx context.Context, mgr *manager, instanceID string, timeout time.Duration) error {
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
+		logs, err := collectLogs(ctx, mgr, instanceID, 100)
+		if err == nil && strings.Contains(logs, "[exec-agent] listening on vsock port 2222") {
+			return nil
+		}
+		time.Sleep(500 * time.Millisecond)
+	}
+	return fmt.Errorf("exec-agent not ready after %s", timeout)
+}
+
+// Note: execCommand is defined in network_test.go
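+//
+// Its assumed shape, shown here for orientation only (see network_test.go for
+// the real definition):
+//
+//	output, exitCode, err := execCommand(ctx, inst.VsockSocket, "echo", "hi")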
+
+// TestExecConcurrent tests concurrent exec commands from multiple goroutines.
+// This validates that the exec infrastructure handles concurrent access correctly.
+func TestExecConcurrent(t *testing.T) {
+	if _, err := os.Stat("/dev/kvm"); os.IsNotExist(err) {
+		t.Fatal("/dev/kvm not available")
+	}
+
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	mgr, tmpDir := setupTestManager(t)
+	ctx := context.Background()
+	p := paths.New(tmpDir)
+
+	// Setup image
+	imageManager, err := images.NewManager(p, 1)
+	require.NoError(t, err)
+
+	t.Log("Pulling nginx:alpine image...")
+	_, err = imageManager.CreateImage(ctx, images.CreateImageRequest{
+		Name: "docker.io/library/nginx:alpine",
+	})
+	require.NoError(t, err)
+
+	imgReady := false
+	for i := 0; i < 60; i++ {
+		img, err := imageManager.GetImage(ctx, "docker.io/library/nginx:alpine")
+		if err == nil && img.Status == images.StatusReady {
+			imgReady = true
+			break
+		}
+		time.Sleep(1 * time.Second)
+	}
+	require.True(t, imgReady, "image should become ready within 60s")
+	t.Log("Image ready")
+
+	// Ensure system files
+	systemManager := system.NewManager(p)
+	err = systemManager.EnsureSystemFiles(ctx)
+	require.NoError(t, err)
+
+	// Create nginx instance
+	t.Log("Creating nginx instance...")
+	inst, err := mgr.CreateInstance(ctx, CreateInstanceRequest{
+		Name:           "exec-test",
+		Image:          "docker.io/library/nginx:alpine",
+		Size:           512 * 1024 * 1024,
+		HotplugSize:    512 * 1024 * 1024,
+		OverlaySize:    1024 * 1024 * 1024,
+		Vcpus:          2, // More vCPUs for concurrency
+		NetworkEnabled: false,
+	})
+	require.NoError(t, err)
+	t.Logf("Instance created: %s", inst.Id)
+
+	t.Cleanup(func() {
+		t.Log("Cleaning up...")
+		mgr.DeleteInstance(ctx, inst.Id)
+	})
+
+	// Wait for exec-agent to be ready (retrying here is fine - we're just waiting for startup)
+	err = waitForExecAgent(ctx, mgr, inst.Id, 15*time.Second)
+	require.NoError(t, err, "exec-agent should be ready")
+
+	// Verify exec-agent works with a simple command first
+	_, code, err := execCommand(ctx, inst.VsockSocket, "echo", "ready")
+	require.NoError(t, err, "initial exec should work")
+	require.Equal(t, 0, code, "initial exec should succeed")
+
+	// Run 5 concurrent workers, each doing 25 iterations with its own file.
+	// No retries: this tests that exec works reliably under concurrent load.
+	const numWorkers = 5
+	const numIterations = 25
+
+	t.Logf("Running %d concurrent workers, %d iterations each (no retries)...", numWorkers, numIterations)
+
+	var wg sync.WaitGroup
+	errCh := make(chan error, numWorkers)
+
+	for w := 0; w < numWorkers; w++ {
+		wg.Add(1)
+		go func(workerID int) {
+			defer wg.Done()
+			filename := fmt.Sprintf("/tmp/test%d.txt", workerID)
+
+			for i := 1; i <= numIterations; i++ {
+				// Write (no retry - must work the first time)
+				writeCmd := fmt.Sprintf("echo '%d-%d' > %s", workerID, i, filename)
+				output, code, err := execCommand(ctx, inst.VsockSocket, "/bin/sh", "-c", writeCmd)
+				if err != nil {
+					errCh <- fmt.Errorf("worker %d, iter %d: write error: %w", workerID, i, err)
+					return
+				}
+				if code != 0 {
+					errCh <- fmt.Errorf("worker %d, iter %d: write failed with code %d, output: %s", workerID, i, code, output)
+					return
+				}
+
+				// Read (no retry - must work the first time)
+				output, code, err = execCommand(ctx, inst.VsockSocket, "cat", filename)
+				if err != nil {
+					errCh <- fmt.Errorf("worker %d, iter %d: read error: %w", workerID, i, err)
+					return
+				}
+				if code != 0 {
+					errCh <- fmt.Errorf("worker %d, iter %d: read failed with code %d", workerID, i, code)
+					return
+				}
+
+				expected := fmt.Sprintf("%d-%d", workerID, i)
+				actual := strings.TrimSpace(output)
+				if expected != actual {
+					errCh <- fmt.Errorf("worker %d, iter %d: expected %q, got %q", workerID, i, expected, actual)
+					return
+				}
+			}
+			t.Logf("Worker %d completed %d iterations", workerID, numIterations)
+		}(w)
+	}
+
+	// Wait for all workers
+	wg.Wait()
+	close(errCh)
+
+	// Check for errors
+	var errs []error
+	for err := range errCh {
+		errs = append(errs, err)
+	}
+	require.Empty(t, errs, "concurrent exec failed: %v", errs)
+
+	t.Logf("All %d workers completed %d iterations each (total: %d exec pairs)", numWorkers, numIterations, numWorkers*numIterations*2)
+
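+	// Why streamDuration+2 below: each of the five workers sleeps ~2s, so
+	// concurrent streams should finish in roughly one sleep period plus slack,
+	// while a serialized transport would need ~streamWorkers*streamDuration = 10s.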
+	// Phase 2: Test long-running concurrent streams
+	// This verifies streams don't block each other (e.g., multiple shells or streaming commands)
+	t.Log("Phase 2: Testing long-running concurrent streams...")
+
+	const streamWorkers = 5
+	const streamDuration = 2 // seconds
+
+	var streamWg sync.WaitGroup
+	streamErrors := make(chan error, streamWorkers)
+	streamStart := time.Now()
+
+	for w := 0; w < streamWorkers; w++ {
+		streamWg.Add(1)
+		go func(workerID int) {
+			defer streamWg.Done()
+
+			// Command that takes ~2 seconds and produces output
+			cmd := fmt.Sprintf("sleep %d && echo 'stream-%d-done'", streamDuration, workerID)
+			output, code, err := execCommand(ctx, inst.VsockSocket, "/bin/sh", "-c", cmd)
+			if err != nil {
+				streamErrors <- fmt.Errorf("stream worker %d: error: %w", workerID, err)
+				return
+			}
+			if code != 0 {
+				streamErrors <- fmt.Errorf("stream worker %d: exit code %d", workerID, code)
+				return
+			}
+			expected := fmt.Sprintf("stream-%d-done", workerID)
+			if !strings.Contains(output, expected) {
+				streamErrors <- fmt.Errorf("stream worker %d: expected %q in output, got %q", workerID, expected, output)
+				return
+			}
+		}(w)
+	}
+
+	streamWg.Wait()
+	close(streamErrors)
+
+	streamElapsed := time.Since(streamStart)
+
+	// Check for errors
+	var streamErrs []error
+	for err := range streamErrors {
+		streamErrs = append(streamErrs, err)
+	}
+	require.Empty(t, streamErrs, "long-running streams failed: %v", streamErrs)
+
+	// If concurrent, should complete in ~2-4s; if serialized would be ~10s
+	maxExpected := time.Duration(streamDuration+2) * time.Second
+	require.Less(t, streamElapsed, maxExpected,
+		"streams appear serialized - took %v, expected < %v", streamElapsed, maxExpected)
+
+	t.Logf("Long-running streams completed in %v (concurrent OK)", streamElapsed)
+}
+
diff --git a/lib/system/exec_agent/main.go b/lib/system/exec_agent/main.go
index 54bb79c1..35744d5a 100644
--- a/lib/system/exec_agent/main.go
+++ b/lib/system/exec_agent/main.go
@@ -3,6 +3,7 @@ package main
 import (
 	"context"
 	"fmt"
+	"io"
 	"log"
 	"os"
 	"os/exec"
@@ -113,8 +114,12 @@ func (s *execServer) executeNoTTY(ctx context.Context, stream pb.ExecService_Exe
 		return fmt.Errorf("start command: %w", err)
 	}
 
-	// Use WaitGroup to ensure all output is sent before exit code
+	// Mutex to protect concurrent stream.Send calls (gRPC streams are not thread-safe)
+	var sendMu sync.Mutex
+
+	// Use WaitGroup to ensure all output is read before sending
 	var wg sync.WaitGroup
+	var stdoutData, stderrData []byte
 
 	// Handle stdin in background
 	go func() {
@@ -130,47 +135,54 @@ func (s *execServer) executeNoTTY(ctx context.Context, stream pb.ExecService_Exe
 		}
 	}()
 
-	// Stream stdout
+	// Read all stdout/stderr BEFORE calling Wait() - Wait() closes the pipes!
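+	// (os/exec documents that Wait closes these pipes once the command exits,
+	// so all reads must complete before Wait is called; reading and waiting
+	// concurrently - as the removed code did - can drop tail output.)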
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
-		buf := make([]byte, 32 * 1024)
-		for {
-			n, err := stdout.Read(buf)
-			if n > 0 {
-				stream.Send(&pb.ExecResponse{
-					Response: &pb.ExecResponse_Stdout{Stdout: buf[:n]},
-				})
-			}
-			if err != nil {
-				return
-			}
-		}
+		data, _ := io.ReadAll(stdout)
+		stdoutData = data
 	}()
 
-	// Stream stderr
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
-		buf := make([]byte, 32 * 1024)
-		for {
-			n, err := stderr.Read(buf)
-			if n > 0 {
-				stream.Send(&pb.ExecResponse{
-					Response: &pb.ExecResponse_Stderr{Stderr: buf[:n]},
-				})
-			}
-			if err != nil {
-				return
-			}
-		}
+		data, _ := io.ReadAll(stderr)
+		stderrData = data
 	}()
 
-	// Wait for command to finish or context cancellation
-	waitErr := cmd.Wait()
-
-	// Wait for all output to be sent
+	// Wait for all reads to complete FIRST (before Wait closes pipes)
 	wg.Wait()
+
+	// Now safe to call Wait - pipes are fully drained
+	waitErr := cmd.Wait()
+
+	// Send the buffered output in chunks (note: it is delivered only after the command exits)
+	const chunkSize = 32 * 1024
+	for i := 0; i < len(stdoutData); i += chunkSize {
+		end := i + chunkSize
+		if end > len(stdoutData) {
+			end = len(stdoutData)
+		}
+		sendMu.Lock()
+		stream.Send(&pb.ExecResponse{
+			Response: &pb.ExecResponse_Stdout{Stdout: stdoutData[i:end]},
+		})
+		sendMu.Unlock()
+	}
+	for i := 0; i < len(stderrData); i += chunkSize {
+		end := i + chunkSize
+		if end > len(stderrData) {
+			end = len(stderrData)
+		}
+		sendMu.Lock()
+		stream.Send(&pb.ExecResponse{
+			Response: &pb.ExecResponse_Stderr{Stderr: stderrData[i:end]},
+		})
+		sendMu.Unlock()
+	}
 
 	exitCode := int32(0)
 	if cmd.ProcessState != nil {
@@ -213,6 +225,9 @@ func (s *execServer) executeTTY(ctx context.Context, stream pb.ExecService_ExecS
 	}
 	defer ptmx.Close()
 
+	// Mutex to protect concurrent stream.Send calls (gRPC streams are not thread-safe)
+	var sendMu sync.Mutex
+
 	// Use WaitGroup to ensure all output is sent before exit code
 	var wg sync.WaitGroup
 
@@ -238,9 +253,11 @@ func (s *execServer) executeTTY(ctx context.Context, stream pb.ExecService_ExecS
 	for {
 		n, err := ptmx.Read(buf)
 		if n > 0 {
+			sendMu.Lock()
 			stream.Send(&pb.ExecResponse{
 				Response: &pb.ExecResponse_Stdout{Stdout: buf[:n]},
 			})
+			sendMu.Unlock()
 		}
 		if err != nil {
 			return