diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e4d18ad20..fd90f9718 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,7 +71,9 @@ jobs: - name: Run tests with coverage if: github.ref == 'refs/heads/main' - run: cargo llvm-cov nextest --workspace --json --output-path coverage.json --html --output-dir coverage-html -E 'not (test(/live/) | test(/integration/))' + run: | + cargo llvm-cov nextest --workspace --json --output-path coverage.json -E 'not (test(/live/) | test(/integration/))' + cargo llvm-cov report --html --output-dir coverage-html - name: Upload coverage HTML report if: github.ref == 'refs/heads/main' diff --git a/Cargo.lock b/Cargo.lock index fc20f351f..bb695575a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1351,6 +1351,7 @@ dependencies = [ "secure-container-runtime", "serde", "serde_json", + "serial_test", "tempfile", "thiserror 2.0.17", "tokio", @@ -1875,7 +1876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d162beedaa69905488a8da94f5ac3edb4dd4788b732fadb7bd120b2625c1976" dependencies = [ "data-encoding", - "syn 2.0.111", + "syn 1.0.109", ] [[package]] @@ -6963,6 +6964,15 @@ dependencies = [ "yap", ] +[[package]] +name = "scc" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.28" @@ -7032,6 +7042,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sdd" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" + [[package]] name = "sec1" version = "0.7.3" @@ -7362,6 +7378,32 @@ dependencies = [ "serde", ] +[[package]] +name = "serial_test" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d0b343e184fc3b7bb44dff0705fffcf4b3756ba6aff420dddd8b24ca145e555" +dependencies = [ + "futures-executor", + "futures-util", + "log", + "once_cell", + "parking_lot 0.12.5", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "sha1" version = "0.10.6" diff --git a/crates/challenge-orchestrator/Cargo.toml b/crates/challenge-orchestrator/Cargo.toml index 66afcd05e..1a00bb15d 100644 --- a/crates/challenge-orchestrator/Cargo.toml +++ b/crates/challenge-orchestrator/Cargo.toml @@ -41,3 +41,4 @@ hostname = "0.4" [dev-dependencies] tempfile = { workspace = true } tokio-test = "0.4" +serial_test = "3.2" diff --git a/crates/challenge-orchestrator/src/backend.rs b/crates/challenge-orchestrator/src/backend.rs index 140df4777..4a7406db3 100644 --- a/crates/challenge-orchestrator/src/backend.rs +++ b/crates/challenge-orchestrator/src/backend.rs @@ -1,8 +1,9 @@ //! Container backend abstraction //! -//! Provides a unified interface for container management that can use: -//! - SecureContainerClient via broker (DEFAULT for production validators) -//! - Direct Docker (ONLY for local development when DEVELOPMENT_MODE=true) +//! This module selects the concrete runtime bridge that the orchestrator uses +//! to manipulate containers. In production it proxies through the +//! `secure-container-runtime` broker while still allowing direct Docker access +//! when a developer explicitly opts into `DEVELOPMENT_MODE=true`. //! //! ## Backend Selection (Priority Order) //! @@ -13,23 +14,29 @@ //! //! ## Security //! -//! In production, challenges MUST run through the secure broker. -//! The broker enforces: -//! - Image whitelisting (only ghcr.io/platformnetwork/) -//! - Non-privileged containers -//! - Resource limits -//! - No Docker socket access for challenges - -use crate::{ChallengeContainerConfig, ChallengeInstance, ContainerStatus}; +//! The secure backend enforces: +//! - Image allow-listing (`ghcr.io/platformnetwork/`) +//! - Non-privileged containers with resource limits baked in +//! - Network isolation handled by the broker +//! - No direct Docker socket exposure for workloads + +use crate::{ChallengeContainerConfig, ChallengeDocker, ChallengeInstance, ContainerStatus}; use async_trait::async_trait; use secure_container_runtime::{ - ContainerConfigBuilder, ContainerState, NetworkMode, SecureContainerClient, + CleanupResult as BrokerCleanupResult, ContainerConfig, ContainerConfigBuilder, ContainerError, + ContainerInfo, ContainerStartResult, ContainerState, NetworkMode, SecureContainerClient, }; use std::path::Path; +use std::sync::Arc; use tracing::{error, info, warn}; /// Default broker socket path pub const DEFAULT_BROKER_SOCKET: &str = "/var/run/platform/broker.sock"; +const BROKER_SOCKET_OVERRIDE_ENV: &str = "BROKER_SOCKET_OVERRIDE"; + +fn default_broker_socket_path() -> String { + std::env::var(BROKER_SOCKET_OVERRIDE_ENV).unwrap_or_else(|_| DEFAULT_BROKER_SOCKET.to_string()) +} /// Container backend trait for managing challenge containers #[async_trait] @@ -62,23 +69,159 @@ pub trait ContainerBackend: Send + Sync { async fn list_challenge_containers(&self, challenge_id: &str) -> anyhow::Result>; } +#[async_trait] +pub trait SecureContainerBridge: Send + Sync { + async fn create_container( + &self, + config: ContainerConfig, + ) -> Result<(String, String), ContainerError>; + async fn start_container( + &self, + container_id: &str, + ) -> Result; + async fn get_endpoint(&self, container_id: &str, port: u16) -> Result; + async fn stop_container( + &self, + container_id: &str, + timeout_secs: u32, + ) -> Result<(), ContainerError>; + async fn remove_container(&self, container_id: &str, force: bool) + -> Result<(), ContainerError>; + async fn inspect(&self, container_id: &str) -> Result; + async fn pull_image(&self, image: &str) -> Result<(), ContainerError>; + async fn logs(&self, container_id: &str, tail: usize) -> Result; + async fn cleanup_challenge( + &self, + challenge_id: &str, + ) -> Result; + async fn list_by_challenge( + &self, + challenge_id: &str, + ) -> Result, ContainerError>; +} + +struct SecureClientBridge { + client: SecureContainerClient, +} + +impl SecureClientBridge { + fn new(socket_path: &str) -> Self { + Self { + client: SecureContainerClient::new(socket_path), + } + } +} + +#[async_trait] +impl SecureContainerBridge for SecureClientBridge { + async fn create_container( + &self, + config: ContainerConfig, + ) -> Result<(String, String), ContainerError> { + self.client.create_container(config).await + } + + async fn start_container( + &self, + container_id: &str, + ) -> Result { + self.client.start_container(container_id).await + } + + async fn get_endpoint(&self, container_id: &str, port: u16) -> Result { + self.client.get_endpoint(container_id, port).await + } + + async fn stop_container( + &self, + container_id: &str, + timeout_secs: u32, + ) -> Result<(), ContainerError> { + self.client.stop_container(container_id, timeout_secs).await + } + + async fn remove_container( + &self, + container_id: &str, + force: bool, + ) -> Result<(), ContainerError> { + self.client.remove_container(container_id, force).await + } + + async fn inspect(&self, container_id: &str) -> Result { + self.client.inspect(container_id).await + } + + async fn pull_image(&self, image: &str) -> Result<(), ContainerError> { + self.client.pull_image(image).await + } + + async fn logs(&self, container_id: &str, tail: usize) -> Result { + self.client.logs(container_id, tail).await + } + + async fn cleanup_challenge( + &self, + challenge_id: &str, + ) -> Result { + self.client.cleanup_challenge(challenge_id).await + } + + async fn list_by_challenge( + &self, + challenge_id: &str, + ) -> Result, ContainerError> { + self.client.list_by_challenge(challenge_id).await + } +} + /// Secure container backend using the broker pub struct SecureBackend { - client: SecureContainerClient, + client: Arc, validator_id: String, } impl SecureBackend { /// Create a new secure backend pub fn new(socket_path: &str, validator_id: &str) -> Self { + Self::with_bridge(SecureClientBridge::new(socket_path), validator_id) + } + + #[cfg(test)] + fn test_backend_slot() -> &'static std::sync::Mutex> { + use std::sync::{Mutex, OnceLock}; + static SLOT: OnceLock>> = OnceLock::new(); + SLOT.get_or_init(|| Mutex::new(None)) + } + + #[cfg(test)] + fn take_test_backend() -> Option { + Self::test_backend_slot().lock().unwrap().take() + } + + #[cfg(test)] + pub(crate) fn set_test_backend(backend: SecureBackend) { + Self::test_backend_slot().lock().unwrap().replace(backend); + } + + /// Build a backend from an arbitrary bridge (used for tests) + pub fn with_bridge( + client: impl SecureContainerBridge + 'static, + validator_id: impl Into, + ) -> Self { Self { - client: SecureContainerClient::new(socket_path), - validator_id: validator_id.to_string(), + client: Arc::new(client), + validator_id: validator_id.into(), } } /// Create from environment or default socket pub fn from_env() -> Option { + #[cfg(test)] + if let Some(backend) = Self::take_test_backend() { + return Some(backend); + } + let validator_id = std::env::var("VALIDATOR_HOTKEY").unwrap_or_else(|_| "unknown".to_string()); @@ -91,10 +234,11 @@ impl SecureBackend { warn!(socket = %socket, "Broker socket from env does not exist"); } - // Priority 2: Default socket path - if Path::new(DEFAULT_BROKER_SOCKET).exists() { - info!(socket = %DEFAULT_BROKER_SOCKET, "Using default broker socket"); - return Some(Self::new(DEFAULT_BROKER_SOCKET, &validator_id)); + // Priority 2: Default socket path (allow override for tests) + let default_socket = default_broker_socket_path(); + if Path::new(&default_socket).exists() { + info!(socket = %default_socket, "Using default broker socket"); + return Some(Self::new(&default_socket, &validator_id)); } None @@ -107,7 +251,8 @@ impl SecureBackend { return true; } } - Path::new(DEFAULT_BROKER_SOCKET).exists() + let default_socket = default_broker_socket_path(); + Path::new(&default_socket).exists() } } @@ -233,15 +378,47 @@ impl ContainerBackend for SecureBackend { } /// Direct Docker backend (for local development) +#[derive(Clone)] pub struct DirectDockerBackend { - docker: crate::docker::DockerClient, + docker: Arc, } impl DirectDockerBackend { /// Create a new direct Docker backend pub async fn new() -> anyhow::Result { + #[cfg(test)] + if let Some(result) = Self::take_test_result() { + return result; + } + let docker = crate::docker::DockerClient::connect().await?; - Ok(Self { docker }) + Ok(Self::with_docker(docker)) + } + + /// Build a backend from a custom docker implementation (used for tests) + pub fn with_docker(docker: impl ChallengeDocker + 'static) -> Self { + Self { + docker: Arc::new(docker), + } + } + + #[cfg(test)] + fn test_backend_slot() -> &'static std::sync::Mutex>> + { + use std::sync::OnceLock; + static SLOT: OnceLock>>> = + OnceLock::new(); + SLOT.get_or_init(|| std::sync::Mutex::new(None)) + } + + #[cfg(test)] + fn take_test_result() -> Option> { + Self::test_backend_slot().lock().unwrap().take() + } + + #[cfg(test)] + pub(crate) fn set_test_result(result: anyhow::Result) { + Self::test_backend_slot().lock().unwrap().replace(result); } } @@ -302,24 +479,28 @@ impl ContainerBackend for DirectDockerBackend { /// 2. Broker socket available -> Secure broker (production default) /// 3. No broker + not dev mode -> Error (production requires broker) pub async fn create_backend() -> anyhow::Result> { - // Check if explicitly in development mode - let dev_mode = std::env::var("DEVELOPMENT_MODE") - .map(|v| v == "true" || v == "1") - .unwrap_or(false); - - if dev_mode { - info!("DEVELOPMENT_MODE=true: Using direct Docker (local development)"); - let direct = DirectDockerBackend::new().await?; - return Ok(Box::new(direct)); - } - - // Try to use secure broker (default for production) - if let Some(secure) = SecureBackend::from_env() { - info!("Using secure container broker (production mode)"); - return Ok(Box::new(secure)); + match select_backend_mode() { + BackendMode::Development => { + info!("DEVELOPMENT_MODE=true: Using direct Docker (local development)"); + let direct = DirectDockerBackend::new().await?; + Ok(Box::new(direct)) + } + BackendMode::Secure => { + if let Some(secure) = SecureBackend::from_env() { + info!("Using secure container broker (production mode)"); + Ok(Box::new(secure)) + } else { + warn!( + "Secure backend reported as available but failed to initialize; falling back to Docker" + ); + create_docker_fallback_backend().await + } + } + BackendMode::Fallback => create_docker_fallback_backend().await, } +} - // No broker available - try Docker as last resort but warn +async fn create_docker_fallback_backend() -> anyhow::Result> { warn!("Broker not available. Attempting Docker fallback..."); warn!("This should only happen in local development!"); warn!("Set DEVELOPMENT_MODE=true to suppress this warning, or start the broker."); @@ -333,15 +514,33 @@ pub async fn create_backend() -> anyhow::Result> { error!("Cannot connect to Docker: {}", e); error!("For production: Start the container-broker service"); error!("For development: Set DEVELOPMENT_MODE=true and ensure Docker is running"); + let default_socket = default_broker_socket_path(); Err(anyhow::anyhow!( "No container backend available. \ Start broker at {} or set DEVELOPMENT_MODE=true for local Docker", - DEFAULT_BROKER_SOCKET + default_socket )) } } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BackendMode { + Development, + Secure, + Fallback, +} + +pub fn select_backend_mode() -> BackendMode { + if is_development_mode() { + BackendMode::Development + } else if SecureBackend::is_available() { + BackendMode::Secure + } else { + BackendMode::Fallback + } +} + /// Check if running in secure mode (broker available) pub fn is_secure_mode() -> bool { SecureBackend::is_available() @@ -353,3 +552,805 @@ pub fn is_development_mode() -> bool { .map(|v| v == "true" || v == "1") .unwrap_or(false) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::docker::CleanupResult as DockerCleanupResult; + use chrono::Utc; + use platform_core::ChallengeId; + use serial_test::serial; + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + use tempfile::{tempdir, NamedTempFile}; + + fn reset_env() { + for key in [ + "DEVELOPMENT_MODE", + "CONTAINER_BROKER_SOCKET", + "VALIDATOR_HOTKEY", + BROKER_SOCKET_OVERRIDE_ENV, + ] { + std::env::remove_var(key); + } + } + + #[test] + #[serial] + fn test_is_development_mode_reflects_env() { + reset_env(); + assert!(!is_development_mode()); + + std::env::set_var("DEVELOPMENT_MODE", "1"); + assert!(is_development_mode()); + + std::env::set_var("DEVELOPMENT_MODE", "false"); + assert!(!is_development_mode()); + reset_env(); + } + + #[test] + #[serial] + fn test_secure_backend_from_env_detects_socket() { + reset_env(); + let temp_socket = NamedTempFile::new().expect("temp socket path"); + let socket_path = temp_socket.path().to_path_buf(); + std::env::set_var("CONTAINER_BROKER_SOCKET", &socket_path); + std::env::set_var("VALIDATOR_HOTKEY", "validator123"); + + let backend = SecureBackend::from_env().expect("should create backend from env"); + assert_eq!(backend.validator_id, "validator123"); + + reset_env(); + drop(temp_socket); + } + + #[test] + #[serial] + fn test_is_secure_mode_uses_env_socket() { + reset_env(); + let temp_socket = NamedTempFile::new().expect("temp socket path"); + let socket_path = temp_socket.path().to_path_buf(); + std::env::set_var("CONTAINER_BROKER_SOCKET", &socket_path); + + assert!(is_secure_mode()); + + reset_env(); + drop(temp_socket); + } + + #[test] + #[serial] + fn test_secure_backend_is_available_with_override_socket() { + reset_env(); + let temp_socket = NamedTempFile::new().expect("temp socket path"); + let socket_path = temp_socket.path().to_path_buf(); + std::env::set_var(BROKER_SOCKET_OVERRIDE_ENV, &socket_path); + + assert!(SecureBackend::is_available()); + + reset_env(); + drop(temp_socket); + } + + #[test] + #[serial] + fn test_select_backend_mode_prefers_development_mode() { + reset_env(); + std::env::set_var("DEVELOPMENT_MODE", "true"); + + assert_eq!(select_backend_mode(), BackendMode::Development); + + reset_env(); + } + + #[test] + #[serial] + fn test_select_backend_mode_prefers_secure_when_broker_available() { + reset_env(); + let temp_socket = NamedTempFile::new().expect("temp socket path"); + let socket_path = temp_socket.path().to_path_buf(); + std::env::set_var(BROKER_SOCKET_OVERRIDE_ENV, &socket_path); + + assert_eq!(select_backend_mode(), BackendMode::Secure); + + reset_env(); + drop(temp_socket); + } + + #[test] + #[serial] + fn test_select_backend_mode_falls_back_without_broker() { + reset_env(); + let dir = tempdir().expect("temp dir"); + let missing_socket = dir.path().join("missing.sock"); + std::env::set_var(BROKER_SOCKET_OVERRIDE_ENV, &missing_socket); + + assert_eq!(select_backend_mode(), BackendMode::Fallback); + + reset_env(); + } + + #[test] + #[serial] + fn test_secure_backend_from_env_uses_default_socket() { + reset_env(); + let temp_socket = NamedTempFile::new().expect("temp socket path"); + let socket_path = temp_socket.path().to_path_buf(); + std::env::set_var(BROKER_SOCKET_OVERRIDE_ENV, &socket_path); + + let backend = SecureBackend::from_env().expect("backend from default socket"); + assert_eq!(backend.validator_id, "unknown"); + + reset_env(); + } + + #[tokio::test] + #[serial] + async fn test_secure_backend_start_challenge_via_bridge() { + reset_env(); + let bridge = RecordingSecureBridge::default(); + bridge.set_create_response("container-123", "challenge-container"); + bridge.set_endpoint("container-123", "http://sandbox:8080"); + + let backend = SecureBackend::with_bridge(bridge.clone(), "validator-abc"); + let config = sample_config("ghcr.io/platformnetwork/demo:v1"); + + let instance = backend + .start_challenge(&config) + .await + .expect("start succeeds"); + + assert_eq!(instance.container_id, "container-123"); + assert_eq!(instance.endpoint, "http://sandbox:8080"); + assert_eq!(instance.image, config.docker_image); + + let ops = bridge.operations(); + assert!(ops.iter().any(|op| op.starts_with("create:"))); + assert!(ops.iter().any(|op| op.starts_with("start:"))); + assert!(ops.iter().any(|op| op.starts_with("endpoint:"))); + + reset_env(); + } + + #[tokio::test] + #[serial] + async fn test_secure_backend_covers_remaining_methods() { + reset_env(); + let bridge = RecordingSecureBridge::default(); + bridge.set_inspect_state("running", ContainerState::Running); + bridge.set_inspect_state("stopped", ContainerState::Stopped); + bridge.set_logs("running", "log output"); + bridge.set_cleanup_result(BrokerCleanupResult { + total: 2, + stopped: 2, + removed: 2, + errors: Vec::new(), + }); + bridge.set_list( + "challenge-1", + vec![ + container_info("alpha", ContainerState::Running), + container_info("beta", ContainerState::Stopped), + ], + ); + let backend = SecureBackend::with_bridge(bridge.clone(), "validator-xyz"); + + backend + .stop_container("running") + .await + .expect("stop delegates"); + backend + .remove_container("running") + .await + .expect("remove delegates"); + backend + .pull_image("ghcr.io/platformnetwork/demo:v2") + .await + .expect("pull delegates"); + let logs = backend + .get_logs("running", 50) + .await + .expect("logs delegates"); + assert_eq!(logs, "log output"); + assert!(backend + .is_container_running("running") + .await + .expect("running state")); + assert!(!backend + .is_container_running("stopped") + .await + .expect("stopped state")); + + let removed = backend + .cleanup_challenge("challenge-1") + .await + .expect("cleanup delegates"); + assert_eq!(removed, 2); + + let ids = backend + .list_challenge_containers("challenge-1") + .await + .expect("list delegates"); + assert_eq!(ids, vec!["alpha".to_string(), "beta".to_string()]); + + let ops = bridge.operations(); + assert!(ops.iter().any(|op| op.starts_with("stop:"))); + assert!(ops.iter().any(|op| op.starts_with("remove:"))); + assert!(ops.iter().any(|op| op.starts_with("pull:"))); + assert!(ops.iter().any(|op| op.starts_with("logs:"))); + assert!(ops.iter().any(|op| op.starts_with("inspect:"))); + assert!(ops.iter().any(|op| op.starts_with("cleanup:"))); + assert!(ops.iter().any(|op| op.starts_with("list:"))); + + reset_env(); + } + + #[tokio::test] + #[serial] + async fn test_direct_backend_delegates_to_docker() { + let docker = RecordingChallengeDocker::default(); + docker.set_list(vec!["container-1".to_string(), "other".to_string()]); + + let backend = DirectDockerBackend::with_docker(docker.clone()); + let mut config = sample_config("ghcr.io/platformnetwork/demo:v3"); + config.challenge_id = ChallengeId::new(); + + backend.pull_image(&config.docker_image).await.unwrap(); + let instance = backend.start_challenge(&config).await.unwrap(); + docker.set_running(&instance.container_id, true); + docker.set_logs(&instance.container_id, "container logs"); + backend + .stop_container(&instance.container_id) + .await + .unwrap(); + backend + .remove_container(&instance.container_id) + .await + .unwrap(); + assert!(backend + .is_container_running(&instance.container_id) + .await + .unwrap()); + let logs = backend.get_logs(&instance.container_id, 10).await.unwrap(); + assert_eq!(logs, "container logs"); + + let listed = backend.list_challenge_containers("unused").await.unwrap(); + assert_eq!(listed.len(), 2); + + let ops = docker.operations(); + assert!(ops.iter().any(|op| op.starts_with("pull:"))); + assert!(ops.iter().any(|op| op.starts_with("start:"))); + assert!(ops.iter().any(|op| op.starts_with("stop:"))); + assert!(ops.iter().any(|op| op.starts_with("remove:"))); + assert!(ops.iter().any(|op| op.starts_with("logs:"))); + } + + #[tokio::test] + #[serial] + async fn test_direct_backend_cleanup_filters_by_challenge_id() { + let docker = RecordingChallengeDocker::default(); + let challenge_id = ChallengeId::new(); + let challenge_str = challenge_id.to_string(); + docker.set_list(vec![ + format!("{challenge_str}-a"), + "platform-helper".to_string(), + format!("other-{challenge_str}"), + ]); + + let backend = DirectDockerBackend::with_docker(docker.clone()); + let removed = backend + .cleanup_challenge(&challenge_str) + .await + .expect("cleanup succeeds"); + assert_eq!(removed, 2); + + let ops = docker.operations(); + assert!(ops.iter().filter(|op| op.starts_with("stop:")).count() >= 2); + assert!(ops.iter().filter(|op| op.starts_with("remove:")).count() >= 2); + } + + #[tokio::test] + #[serial] + async fn test_create_backend_uses_direct_in_dev_mode() { + reset_env(); + std::env::set_var("DEVELOPMENT_MODE", "true"); + let docker = RecordingChallengeDocker::default(); + DirectDockerBackend::set_test_result(Ok(DirectDockerBackend::with_docker(docker.clone()))); + + let backend = create_backend().await.expect("backend"); + backend + .pull_image("ghcr.io/platformnetwork/test:v1") + .await + .unwrap(); + + assert!(docker + .operations() + .iter() + .any(|op| op == "pull:ghcr.io/platformnetwork/test:v1")); + + reset_env(); + } + + #[tokio::test] + #[serial] + async fn test_create_backend_uses_secure_when_broker_available() { + reset_env(); + let temp_socket = NamedTempFile::new().expect("temp socket path"); + let socket_path = temp_socket.path().to_path_buf(); + std::env::set_var(BROKER_SOCKET_OVERRIDE_ENV, &socket_path); + + let bridge = RecordingSecureBridge::default(); + SecureBackend::set_test_backend(SecureBackend::with_bridge( + bridge.clone(), + "validator-secure", + )); + + let backend = create_backend().await.expect("secure backend"); + backend + .pull_image("ghcr.io/platformnetwork/secure:v1") + .await + .unwrap(); + + assert!(bridge + .operations() + .iter() + .any(|op| op == "pull:ghcr.io/platformnetwork/secure:v1")); + + reset_env(); + drop(temp_socket); + } + + #[tokio::test] + #[serial] + async fn test_create_backend_falls_back_when_secure_missing() { + reset_env(); + let dir = tempdir().expect("temp dir"); + let missing_socket = dir.path().join("missing.sock"); + std::env::set_var(BROKER_SOCKET_OVERRIDE_ENV, &missing_socket); + DirectDockerBackend::set_test_result(Ok(DirectDockerBackend::with_docker( + RecordingChallengeDocker::default(), + ))); + + let backend = create_backend().await.expect("fallback backend"); + backend + .pull_image("ghcr.io/platformnetwork/fallback:v1") + .await + .unwrap(); + + reset_env(); + } + + #[tokio::test] + #[serial] + async fn test_create_docker_fallback_backend_reports_error() { + reset_env(); + DirectDockerBackend::set_test_result(Err(anyhow::anyhow!("boom"))); + let err = match create_docker_fallback_backend().await { + Ok(_) => panic!("expected error"), + Err(err) => err, + }; + assert!(err.to_string().contains("No container backend available")); + reset_env(); + } + + fn sample_config(image: &str) -> ChallengeContainerConfig { + ChallengeContainerConfig { + challenge_id: ChallengeId::new(), + name: "challenge".to_string(), + docker_image: image.to_string(), + mechanism_id: 0, + emission_weight: 1.0, + timeout_secs: 300, + cpu_cores: 1.0, + memory_mb: 512, + gpu_required: false, + } + } + + fn container_info(id: &str, state: ContainerState) -> ContainerInfo { + ContainerInfo { + id: id.to_string(), + name: format!("{id}-container"), + challenge_id: "challenge-1".to_string(), + owner_id: "owner".to_string(), + image: "ghcr.io/platformnetwork/demo".to_string(), + state, + created_at: Utc::now(), + ports: HashMap::new(), + endpoint: None, + labels: HashMap::new(), + } + } + + #[derive(Clone, Default)] + struct RecordingSecureBridge { + inner: Arc, + } + + struct RecordingSecureBridgeInner { + operations: Mutex>, + inspect_map: Mutex>, + endpoint_map: Mutex>, + logs_map: Mutex>, + list_map: Mutex>>, + cleanup_result: Mutex, + create_response: Mutex<(String, String)>, + } + + impl Default for RecordingSecureBridgeInner { + fn default() -> Self { + Self { + operations: Mutex::new(Vec::new()), + inspect_map: Mutex::new(HashMap::new()), + endpoint_map: Mutex::new(HashMap::new()), + logs_map: Mutex::new(HashMap::new()), + list_map: Mutex::new(HashMap::new()), + cleanup_result: Mutex::new(BrokerCleanupResult { + total: 0, + stopped: 0, + removed: 0, + errors: Vec::new(), + }), + create_response: Mutex::new(("container-id".to_string(), "container".to_string())), + } + } + } + + impl RecordingSecureBridge { + fn operations(&self) -> Vec { + self.inner.operations.lock().unwrap().clone() + } + + fn set_inspect_state(&self, id: &str, state: ContainerState) { + self.inner + .inspect_map + .lock() + .unwrap() + .insert(id.to_string(), container_info(id, state)); + } + + fn set_endpoint(&self, id: &str, endpoint: &str) { + self.inner + .endpoint_map + .lock() + .unwrap() + .insert(id.to_string(), endpoint.to_string()); + } + + fn set_logs(&self, id: &str, logs: &str) { + self.inner + .logs_map + .lock() + .unwrap() + .insert(id.to_string(), logs.to_string()); + } + + fn set_list(&self, challenge: &str, containers: Vec) { + self.inner + .list_map + .lock() + .unwrap() + .insert(challenge.to_string(), containers); + } + + fn set_cleanup_result(&self, result: BrokerCleanupResult) { + *self.inner.cleanup_result.lock().unwrap() = result; + } + + fn set_create_response(&self, id: &str, name: &str) { + *self.inner.create_response.lock().unwrap() = (id.to_string(), name.to_string()); + } + } + + #[async_trait] + impl SecureContainerBridge for RecordingSecureBridge { + async fn create_container( + &self, + config: ContainerConfig, + ) -> Result<(String, String), ContainerError> { + self.inner + .operations + .lock() + .unwrap() + .push(format!("create:{}", config.challenge_id)); + Ok(self.inner.create_response.lock().unwrap().clone()) + } + + async fn start_container( + &self, + container_id: &str, + ) -> Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("start:{container_id}")); + Ok(ContainerStartResult { + container_id: container_id.to_string(), + ports: HashMap::new(), + endpoint: None, + }) + } + + async fn get_endpoint( + &self, + container_id: &str, + port: u16, + ) -> Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("endpoint:{container_id}:{port}")); + self.inner + .endpoint_map + .lock() + .unwrap() + .get(container_id) + .cloned() + .ok_or_else(|| ContainerError::ContainerNotFound(container_id.to_string())) + } + + async fn stop_container( + &self, + container_id: &str, + timeout_secs: u32, + ) -> Result<(), ContainerError> { + self.inner + .operations + .lock() + .unwrap() + .push(format!("stop:{container_id}:{timeout_secs}")); + Ok(()) + } + + async fn remove_container( + &self, + container_id: &str, + force: bool, + ) -> Result<(), ContainerError> { + self.inner + .operations + .lock() + .unwrap() + .push(format!("remove:{container_id}:{force}")); + Ok(()) + } + + async fn inspect(&self, container_id: &str) -> Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("inspect:{container_id}")); + self.inner + .inspect_map + .lock() + .unwrap() + .get(container_id) + .cloned() + .ok_or_else(|| ContainerError::ContainerNotFound(container_id.to_string())) + } + + async fn pull_image(&self, image: &str) -> Result<(), ContainerError> { + self.inner + .operations + .lock() + .unwrap() + .push(format!("pull:{image}")); + Ok(()) + } + + async fn logs(&self, container_id: &str, tail: usize) -> Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("logs:{container_id}:{tail}")); + self.inner + .logs_map + .lock() + .unwrap() + .get(container_id) + .cloned() + .ok_or_else(|| ContainerError::ContainerNotFound(container_id.to_string())) + } + + async fn cleanup_challenge( + &self, + challenge_id: &str, + ) -> Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("cleanup:{challenge_id}")); + Ok(self.inner.cleanup_result.lock().unwrap().clone()) + } + + async fn list_by_challenge( + &self, + challenge_id: &str, + ) -> Result, ContainerError> { + self.inner + .operations + .lock() + .unwrap() + .push(format!("list:{challenge_id}")); + Ok(self + .inner + .list_map + .lock() + .unwrap() + .get(challenge_id) + .cloned() + .unwrap_or_default()) + } + } + + #[derive(Clone, Default)] + struct RecordingChallengeDocker { + inner: Arc, + } + + #[derive(Default)] + struct RecordingChallengeDockerInner { + operations: Mutex>, + running: Mutex>, + logs: Mutex>, + list: Mutex>, + next_id: Mutex, + } + + impl RecordingChallengeDocker { + fn operations(&self) -> Vec { + self.inner.operations.lock().unwrap().clone() + } + + fn set_running(&self, id: &str, running: bool) { + self.inner + .running + .lock() + .unwrap() + .insert(id.to_string(), running); + } + + fn set_logs(&self, id: &str, logs: &str) { + self.inner + .logs + .lock() + .unwrap() + .insert(id.to_string(), logs.to_string()); + } + + fn set_list(&self, items: Vec) { + *self.inner.list.lock().unwrap() = items; + } + + fn next_instance(&self, config: &ChallengeContainerConfig) -> ChallengeInstance { + let mut guard = self.inner.next_id.lock().unwrap(); + let value = *guard; + *guard += 1; + let suffix = value.to_string(); + sample_instance( + config.challenge_id, + &format!("container-{}", suffix), + &config.docker_image, + ContainerStatus::Running, + ) + } + } + + fn sample_instance( + challenge_id: ChallengeId, + container_id: &str, + image: &str, + status: ContainerStatus, + ) -> ChallengeInstance { + ChallengeInstance { + challenge_id, + container_id: container_id.to_string(), + image: image.to_string(), + endpoint: format!("http://{container_id}"), + started_at: Utc::now(), + status, + } + } + + #[async_trait] + impl ChallengeDocker for RecordingChallengeDocker { + async fn pull_image(&self, image: &str) -> anyhow::Result<()> { + self.inner + .operations + .lock() + .unwrap() + .push(format!("pull:{image}")); + Ok(()) + } + + async fn start_challenge( + &self, + config: &ChallengeContainerConfig, + ) -> anyhow::Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("start:{}", config.challenge_id)); + Ok(self.next_instance(config)) + } + + async fn stop_container(&self, container_id: &str) -> anyhow::Result<()> { + self.inner + .operations + .lock() + .unwrap() + .push(format!("stop:{container_id}")); + Ok(()) + } + + async fn remove_container(&self, container_id: &str) -> anyhow::Result<()> { + self.inner + .operations + .lock() + .unwrap() + .push(format!("remove:{container_id}")); + Ok(()) + } + + async fn is_container_running(&self, container_id: &str) -> anyhow::Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("is_running:{container_id}")); + Ok(*self + .inner + .running + .lock() + .unwrap() + .get(container_id) + .unwrap_or(&false)) + } + + async fn get_logs(&self, container_id: &str, tail: usize) -> anyhow::Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("logs:{container_id}:{tail}")); + Ok(self + .inner + .logs + .lock() + .unwrap() + .get(container_id) + .cloned() + .unwrap_or_default()) + } + + async fn list_challenge_containers(&self) -> anyhow::Result> { + self.inner + .operations + .lock() + .unwrap() + .push("list_containers".to_string()); + Ok(self.inner.list.lock().unwrap().clone()) + } + + async fn cleanup_stale_containers( + &self, + prefix: &str, + _max_age_minutes: u64, + _exclude_patterns: &[&str], + ) -> anyhow::Result { + self.inner + .operations + .lock() + .unwrap() + .push(format!("cleanup:{prefix}")); + Ok(DockerCleanupResult::default()) + } + } +} diff --git a/crates/challenge-orchestrator/src/config.rs b/crates/challenge-orchestrator/src/config.rs index 5ba4f43ec..18fa09b78 100644 --- a/crates/challenge-orchestrator/src/config.rs +++ b/crates/challenge-orchestrator/src/config.rs @@ -1,4 +1,8 @@ //! Configuration types for challenge orchestrator +//! +//! Exposes `OrchestratorConfig`, which is serializable/deserializable with +//! human-friendly duration fields (plain seconds) so it can be shared between +//! the validator process and external tooling. use serde::{Deserialize, Serialize}; use std::time::Duration; @@ -14,10 +18,10 @@ pub struct OrchestratorConfig { /// Health check interval #[serde(with = "humantime_serde")] pub health_check_interval: Duration, - /// Container stop timeout + /// Grace period to give Docker before force-stopping a container #[serde(with = "humantime_serde")] pub stop_timeout: Duration, - /// Docker registry (optional, for private registries) + /// Optional registry credentials for private images pub registry: Option, } @@ -32,7 +36,7 @@ impl Default for OrchestratorConfig { } } -/// Docker registry configuration +/// Optional Docker registry credentials for pulling private challenge images. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct RegistryConfig { pub url: String, @@ -71,4 +75,69 @@ mod tests { assert_eq!(config.network_name, "platform-network"); assert_eq!(config.health_check_interval, Duration::from_secs(30)); } + + #[test] + fn test_default_config_stop_timeout_and_registry() { + let config = OrchestratorConfig::default(); + assert_eq!(config.stop_timeout, Duration::from_secs(30)); + assert!(config.registry.is_none()); + } + + #[test] + fn test_config_serializes_durations_as_seconds() { + let config = OrchestratorConfig { + network_name: "custom".into(), + health_check_interval: Duration::from_secs(45), + stop_timeout: Duration::from_secs(120), + registry: Some(RegistryConfig { + url: "https://registry.example.com".into(), + username: Some("alice".into()), + password: Some("secret".into()), + }), + }; + + let json = serde_json::to_value(&config).expect("serialize config"); + assert_eq!(json["health_check_interval"], 45); + assert_eq!(json["stop_timeout"], 120); + + let round_trip: OrchestratorConfig = serde_json::from_value(json).expect("deserialize"); + assert_eq!(round_trip.health_check_interval, Duration::from_secs(45)); + assert_eq!(round_trip.stop_timeout, Duration::from_secs(120)); + assert_eq!( + round_trip.registry.unwrap().username.as_deref(), + Some("alice") + ); + } + + #[test] + fn test_humantime_deserialize_rejects_negative_values() { + #[derive(Debug, Deserialize)] + struct DurationWrapper { + #[serde(with = "super::humantime_serde")] + value: Duration, + } + + let err = serde_json::from_str::(r#"{"value": -5}"#) + .expect_err("negative durations rejected"); + assert!(err.to_string().contains("invalid value")); + } + + #[test] + fn test_humantime_serializes_large_values() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct DurationWrapper { + #[serde(with = "super::humantime_serde")] + value: Duration, + } + + let original = DurationWrapper { + value: Duration::from_secs(24 * 60 * 60), + }; + let json = serde_json::to_string(&original).expect("serialize duration wrapper"); + assert!(json.contains("86400")); + + let round_trip: DurationWrapper = + serde_json::from_str(&json).expect("deserialize duration wrapper"); + assert_eq!(round_trip, original); + } } diff --git a/crates/challenge-orchestrator/src/docker.rs b/crates/challenge-orchestrator/src/docker.rs index 2c71ffb98..5d3555296 100644 --- a/crates/challenge-orchestrator/src/docker.rs +++ b/crates/challenge-orchestrator/src/docker.rs @@ -1,62 +1,311 @@ //! Docker client wrapper for container management //! -//! SECURITY: Only images from whitelisted registries (ghcr.io/platformnetwork/) -//! are allowed to be pulled or run. This prevents malicious container attacks. +//! Provides the low-level primitives required when the orchestrator is +//! connected directly to Docker (typically during development or when the +//! secure broker is unavailable). Network bootstrap, log streaming, and image +//! pulls are all funneled through a thin trait (`DockerBridge`) that makes it +//! easy to stub the Docker daemon in tests. +//! +//! SECURITY: Only images from allow-listed registries +//! (`ghcr.io/platformnetwork/`) are allowed to be pulled or run. This prevents +//! malicious container attacks when bypassing the broker. use crate::{ChallengeContainerConfig, ChallengeInstance, ContainerStatus}; +use async_trait::async_trait; use bollard::container::{ - Config, CreateContainerOptions, ListContainersOptions, RemoveContainerOptions, - StartContainerOptions, StopContainerOptions, + Config, CreateContainerOptions, InspectContainerOptions, ListContainersOptions, LogsOptions, + RemoveContainerOptions, StartContainerOptions, StopContainerOptions, }; +use bollard::errors::Error as DockerError; use bollard::image::CreateImageOptions; -use bollard::models::{DeviceRequest, HostConfig, PortBinding}; +use bollard::models::{ + ContainerCreateResponse, ContainerInspectResponse, ContainerSummary, CreateImageInfo, + DeviceRequest, HostConfig, Network, PortBinding, +}; +use bollard::network::{ConnectNetworkOptions, CreateNetworkOptions, ListNetworksOptions}; +use bollard::volume::CreateVolumeOptions; use bollard::Docker; -use futures::StreamExt; +use futures::{Stream, StreamExt}; use platform_core::ALLOWED_DOCKER_PREFIXES; use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; use tracing::{debug, error, info, warn}; +type ImageStream = Pin> + Send>>; +type LogStream = + Pin> + Send>>; + +#[async_trait] +pub trait DockerBridge: Send + Sync { + async fn ping(&self) -> Result<(), DockerError>; + async fn list_networks( + &self, + options: Option>, + ) -> Result, DockerError>; + async fn create_network( + &self, + options: CreateNetworkOptions, + ) -> Result<(), DockerError>; + async fn inspect_container( + &self, + id: &str, + options: Option, + ) -> Result; + async fn connect_network( + &self, + network: &str, + options: ConnectNetworkOptions, + ) -> Result<(), DockerError>; + fn create_image_stream(&self, options: Option>) -> ImageStream; + async fn create_volume(&self, options: CreateVolumeOptions) -> Result<(), DockerError>; + async fn create_container( + &self, + options: Option>, + config: Config, + ) -> Result; + async fn start_container( + &self, + id: &str, + options: Option>, + ) -> Result<(), DockerError>; + async fn stop_container( + &self, + id: &str, + options: Option, + ) -> Result<(), DockerError>; + async fn remove_container( + &self, + id: &str, + options: Option, + ) -> Result<(), DockerError>; + async fn list_containers( + &self, + options: Option>, + ) -> Result, DockerError>; + fn logs_stream(&self, id: &str, options: LogsOptions) -> LogStream; +} + +#[derive(Clone)] +struct BollardBridge { + docker: Docker, +} + +impl BollardBridge { + fn new(docker: Docker) -> Self { + Self { docker } + } +} + +#[async_trait] +impl DockerBridge for BollardBridge { + async fn ping(&self) -> Result<(), DockerError> { + self.docker.ping().await.map(|_| ()) + } + + async fn list_networks( + &self, + options: Option>, + ) -> Result, DockerError> { + self.docker.list_networks(options).await + } + + async fn create_network( + &self, + options: CreateNetworkOptions, + ) -> Result<(), DockerError> { + self.docker.create_network(options).await.map(|_| ()) + } + + async fn inspect_container( + &self, + id: &str, + options: Option, + ) -> Result { + self.docker.inspect_container(id, options).await + } + + async fn connect_network( + &self, + network: &str, + options: ConnectNetworkOptions, + ) -> Result<(), DockerError> { + self.docker.connect_network(network, options).await + } + + fn create_image_stream(&self, options: Option>) -> ImageStream { + Box::pin(self.docker.create_image(options, None, None)) + } + + async fn create_volume(&self, options: CreateVolumeOptions) -> Result<(), DockerError> { + self.docker.create_volume(options).await.map(|_| ()) + } + + async fn create_container( + &self, + options: Option>, + config: Config, + ) -> Result { + self.docker.create_container(options, config).await + } + + async fn start_container( + &self, + id: &str, + options: Option>, + ) -> Result<(), DockerError> { + self.docker.start_container(id, options).await + } + + async fn stop_container( + &self, + id: &str, + options: Option, + ) -> Result<(), DockerError> { + self.docker.stop_container(id, options).await + } + + async fn remove_container( + &self, + id: &str, + options: Option, + ) -> Result<(), DockerError> { + self.docker.remove_container(id, options).await + } + + async fn list_containers( + &self, + options: Option>, + ) -> Result, DockerError> { + self.docker.list_containers(options).await + } + + fn logs_stream(&self, id: &str, options: LogsOptions) -> LogStream { + Box::pin(self.docker.logs(id, Some(options))) + } +} + /// Docker client for managing challenge containers +/// +/// The client ensures challenge containers are attached to the configured +/// network, reuses volumes when possible, and funnels blocking Docker API +/// calls through an async-friendly bridge. pub struct DockerClient { - docker: Docker, + docker: Arc, network_name: String, } +#[async_trait] +pub trait ChallengeDocker: Send + Sync { + async fn pull_image(&self, image: &str) -> anyhow::Result<()>; + async fn start_challenge( + &self, + config: &ChallengeContainerConfig, + ) -> anyhow::Result; + async fn stop_container(&self, container_id: &str) -> anyhow::Result<()>; + async fn remove_container(&self, container_id: &str) -> anyhow::Result<()>; + async fn is_container_running(&self, container_id: &str) -> anyhow::Result; + async fn get_logs(&self, container_id: &str, tail: usize) -> anyhow::Result; + async fn list_challenge_containers(&self) -> anyhow::Result>; + async fn cleanup_stale_containers( + &self, + prefix: &str, + max_age_minutes: u64, + exclude_patterns: &[&str], + ) -> anyhow::Result; +} + +#[async_trait] +impl ChallengeDocker for DockerClient { + async fn pull_image(&self, image: &str) -> anyhow::Result<()> { + DockerClient::pull_image(self, image).await + } + + async fn start_challenge( + &self, + config: &ChallengeContainerConfig, + ) -> anyhow::Result { + DockerClient::start_challenge(self, config).await + } + + async fn stop_container(&self, container_id: &str) -> anyhow::Result<()> { + DockerClient::stop_container(self, container_id).await + } + + async fn remove_container(&self, container_id: &str) -> anyhow::Result<()> { + DockerClient::remove_container(self, container_id).await + } + + async fn is_container_running(&self, container_id: &str) -> anyhow::Result { + DockerClient::is_container_running(self, container_id).await + } + + async fn get_logs(&self, container_id: &str, tail: usize) -> anyhow::Result { + DockerClient::get_logs(self, container_id, tail).await + } + + async fn list_challenge_containers(&self) -> anyhow::Result> { + DockerClient::list_challenge_containers(self).await + } + + async fn cleanup_stale_containers( + &self, + prefix: &str, + max_age_minutes: u64, + exclude_patterns: &[&str], + ) -> anyhow::Result { + DockerClient::cleanup_stale_containers(self, prefix, max_age_minutes, exclude_patterns) + .await + } +} + impl DockerClient { + fn from_bridge(docker: Arc, network_name: impl Into) -> Self { + Self { + docker, + network_name: network_name.into(), + } + } + + /// Build a client from a custom bridge (used for tests/mocks) + pub fn with_bridge( + docker: impl DockerBridge + 'static, + network_name: impl Into, + ) -> Self { + Self::from_bridge(Arc::new(docker), network_name) + } + /// Connect to Docker daemon pub async fn connect() -> anyhow::Result { let docker = Docker::connect_with_local_defaults()?; // Verify connection - docker.ping().await?; + let bridge = Arc::new(BollardBridge::new(docker)); + bridge.ping().await?; info!("Connected to Docker daemon"); - Ok(Self { - docker, - network_name: "platform-network".to_string(), - }) + Ok(Self::from_bridge(bridge, "platform-network")) } /// Connect with custom network name pub async fn connect_with_network(network_name: &str) -> anyhow::Result { let docker = Docker::connect_with_local_defaults()?; - docker.ping().await?; + let bridge = Arc::new(BollardBridge::new(docker)); + bridge.ping().await?; - Ok(Self { - docker, - network_name: network_name.to_string(), - }) + Ok(Self::from_bridge(bridge, network_name)) } /// Connect and auto-detect the network from the validator container /// This ensures challenge containers are on the same network as the validator pub async fn connect_auto_detect() -> anyhow::Result { let docker = Docker::connect_with_local_defaults()?; - docker.ping().await?; + let bridge = Arc::new(BollardBridge::new(docker)); + bridge.ping().await?; info!("Connected to Docker daemon"); // Try to detect the network from the current container - let network_name = Self::detect_validator_network_static(&docker) + let network_name = Self::detect_validator_network(&*bridge) .await .unwrap_or_else(|e| { warn!( @@ -68,14 +317,11 @@ impl DockerClient { info!(network = %network_name, "Using network for challenge containers"); - Ok(Self { - docker, - network_name, - }) + Ok(Self::from_bridge(bridge, network_name)) } /// Detect the network the validator container is running on - async fn detect_validator_network_static(docker: &Docker) -> anyhow::Result { + async fn detect_validator_network(docker: &dyn DockerBridge) -> anyhow::Result { // Get our container ID let container_id = Self::get_container_id_static()?; @@ -198,7 +444,11 @@ impl DockerClient { /// Ensure the Docker network exists pub async fn ensure_network(&self) -> anyhow::Result<()> { - let networks = self.docker.list_networks::(None).await?; + let networks = self + .docker + .list_networks(None::>) + .await + .map_err(anyhow::Error::from)?; let exists = networks.iter().any(|n| { n.name @@ -216,7 +466,10 @@ impl DockerClient { ..Default::default() }; - self.docker.create_network(config).await?; + self.docker + .create_network(config) + .await + .map_err(anyhow::Error::from)?; info!(network = %self.network_name, "Created Docker network"); } else { debug!(network = %self.network_name, "Docker network already exists"); @@ -232,7 +485,11 @@ impl DockerClient { let container_id = self.get_self_container_id()?; // Check if already connected - let inspect = self.docker.inspect_container(&container_id, None).await?; + let inspect = self + .docker + .inspect_container(&container_id, None) + .await + .map_err(anyhow::Error::from)?; let networks = inspect .network_settings .as_ref() @@ -260,7 +517,8 @@ impl DockerClient { self.docker .connect_network(&self.network_name, config) - .await?; + .await + .map_err(anyhow::Error::from)?; info!( container = %container_id, @@ -359,11 +617,11 @@ impl DockerClient { info!(image = %image, "Pulling Docker image (whitelisted)"); let options = CreateImageOptions { - from_image: image, + from_image: image.to_string(), ..Default::default() }; - let mut stream = self.docker.create_image(Some(options), None, None); + let mut stream = self.docker.create_image_stream(Some(options)); while let Some(result) = stream.next().await { match result { @@ -457,9 +715,9 @@ impl DockerClient { let volume_name = format!("{}-data", container_name); // Create volumes if they don't exist (Docker will auto-create on mount, but explicit is clearer) - let volume_opts = bollard::volume::CreateVolumeOptions { - name: volume_name.as_str(), - driver: "local", + let volume_opts = CreateVolumeOptions { + name: volume_name.clone(), + driver: "local".to_string(), ..Default::default() }; if let Err(e) = self.docker.create_volume(volume_opts).await { @@ -473,9 +731,9 @@ impl DockerClient { "challenge-{}-cache", config.name.to_lowercase().replace(' ', "-") ); - let cache_volume_opts = bollard::volume::CreateVolumeOptions { - name: cache_volume_name.as_str(), - driver: "local", + let cache_volume_opts = CreateVolumeOptions { + name: cache_volume_name.clone(), + driver: "local".to_string(), ..Default::default() }; if let Err(e) = self.docker.create_volume(cache_volume_opts).await { @@ -489,9 +747,9 @@ impl DockerClient { let evals_volume = "term-challenge-evals"; for vol_name in [tasks_volume, dind_cache_volume, evals_volume] { - let vol_opts = bollard::volume::CreateVolumeOptions { - name: vol_name, - driver: "local", + let vol_opts = CreateVolumeOptions { + name: vol_name.to_string(), + driver: "local".to_string(), ..Default::default() }; if let Err(e) = self.docker.create_volume(vol_opts).await { @@ -687,23 +945,29 @@ impl DockerClient { // Create container let options = CreateContainerOptions { - name: &container_name, + name: container_name.clone(), platform: None, }; let response = self .docker .create_container(Some(options), container_config) - .await?; + .await + .map_err(anyhow::Error::from)?; let container_id = response.id; // Start container self.docker .start_container(&container_id, None::>) - .await?; + .await + .map_err(anyhow::Error::from)?; // Get assigned port - let inspect = self.docker.inspect_container(&container_id, None).await?; + let inspect = self + .docker + .inspect_container(&container_id, None) + .await + .map_err(anyhow::Error::from)?; let port = inspect .network_settings .and_then(|ns| ns.ports) @@ -805,24 +1069,27 @@ impl DockerClient { /// List all challenge containers pub async fn list_challenge_containers(&self) -> anyhow::Result> { - let mut filters = HashMap::new(); - filters.insert("name", vec!["challenge-"]); - filters.insert("network", vec![self.network_name.as_str()]); + let mut filters: HashMap> = HashMap::new(); + filters.insert("name".to_string(), vec!["challenge-".to_string()]); + filters.insert("network".to_string(), vec![self.network_name.clone()]); - let options = ListContainersOptions { + let options = ListContainersOptions:: { all: true, filters, ..Default::default() }; - let containers = self.docker.list_containers(Some(options)).await?; + let containers = self + .docker + .list_containers(Some(options)) + .await + .map_err(anyhow::Error::from)?; Ok(containers.into_iter().filter_map(|c| c.id).collect()) } /// Get container logs pub async fn get_logs(&self, container_id: &str, tail: usize) -> anyhow::Result { - use bollard::container::LogsOptions; use futures::TryStreamExt; let options = LogsOptions:: { @@ -834,7 +1101,7 @@ impl DockerClient { let logs: Vec<_> = self .docker - .logs(container_id, Some(options)) + .logs_stream(container_id, options) .try_collect() .await?; @@ -867,12 +1134,14 @@ impl DockerClient { let mut result = CleanupResult::default(); // List ALL containers (including stopped) - let options = ListContainersOptions:: { - all: true, - ..Default::default() - }; + let mut options: ListContainersOptions = Default::default(); + options.all = true; - let containers = self.docker.list_containers(Some(options)).await?; + let containers = self + .docker + .list_containers(Some(options)) + .await + .map_err(anyhow::Error::from)?; let now = chrono::Utc::now().timestamp(); let max_age_secs = (max_age_minutes * 60) as i64; @@ -941,8 +1210,335 @@ impl DockerClient { } } +#[cfg(test)] +mod tests { + use super::*; + use bollard::models::EndpointSettings; + use futures::StreamExt; + use serial_test::serial; + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + fn reset_env(keys: &[&str]) { + for key in keys { + std::env::remove_var(key); + } + } + + #[test] + #[serial] + fn test_is_image_allowed_enforces_whitelist() { + reset_env(&["DEVELOPMENT_MODE"]); + assert!(DockerClient::is_image_allowed( + "ghcr.io/platformnetwork/challenge:latest" + )); + assert!(!DockerClient::is_image_allowed( + "docker.io/library/alpine:latest" + )); + } + + #[test] + #[serial] + fn test_is_image_allowed_allows_dev_mode_override() { + std::env::set_var("DEVELOPMENT_MODE", "true"); + assert!(DockerClient::is_image_allowed( + "docker.io/library/alpine:latest" + )); + reset_env(&["DEVELOPMENT_MODE"]); + } + + #[test] + #[serial] + fn test_is_image_allowed_case_insensitive() { + reset_env(&["DEVELOPMENT_MODE"]); + assert!(DockerClient::is_image_allowed( + "GHCR.IO/PLATFORMNETWORK/IMAGE:TAG" + )); + } + + #[test] + #[serial] + fn test_get_validator_suffix_prefers_validator_name() { + reset_env(&["VALIDATOR_NAME", "HOSTNAME"]); + std::env::set_var("VALIDATOR_NAME", "Node 42-Test"); + std::env::set_var("HOSTNAME", "should_not_be_used"); + + let suffix = DockerClient::get_validator_suffix(); + assert_eq!(suffix, "node42test"); + + reset_env(&["VALIDATOR_NAME", "HOSTNAME"]); + } + + #[test] + #[serial] + fn test_get_validator_suffix_uses_container_id_from_hostname() { + reset_env(&["VALIDATOR_NAME"]); + std::env::set_var("HOSTNAME", "abcdef123456"); + + let suffix = DockerClient::get_validator_suffix(); + assert_eq!(suffix, "abcdef123456"); + + reset_env(&["HOSTNAME"]); + } + + #[tokio::test] + #[ignore = "requires Docker"] + async fn test_docker_connect() { + let client = DockerClient::connect().await; + assert!(client.is_ok()); + } + + #[derive(Clone, Default)] + struct RecordingBridge { + inner: Arc, + } + + #[derive(Default)] + struct RecordingBridgeInner { + networks: Mutex>, + created_networks: Mutex>, + containers: Mutex>, + removed: Mutex>, + inspect_map: Mutex>, + connect_calls: Mutex>, + } + + impl RecordingBridge { + fn with_networks(names: &[&str]) -> Self { + let bridge = RecordingBridge::default(); + { + let mut lock = bridge.inner.networks.lock().unwrap(); + for name in names { + lock.push(Network { + name: Some(name.to_string()), + ..Default::default() + }); + } + } + bridge + } + + fn created_networks(&self) -> Vec { + self.inner.created_networks.lock().unwrap().clone() + } + + fn set_inspect_networks(&self, container_id: &str, networks: &[&str]) { + let mut map: HashMap = HashMap::new(); + for name in networks { + map.insert(name.to_string(), Default::default()); + } + let response = ContainerInspectResponse { + network_settings: Some(bollard::models::NetworkSettings { + networks: Some(map), + ..Default::default() + }), + ..Default::default() + }; + self.inner + .inspect_map + .lock() + .unwrap() + .insert(container_id.to_string(), response); + } + + fn set_containers(&self, containers: Vec) { + *self.inner.containers.lock().unwrap() = containers; + } + + fn removed_containers(&self) -> Vec { + self.inner.removed.lock().unwrap().clone() + } + + fn connect_calls(&self) -> Vec<(String, String)> { + self.inner.connect_calls.lock().unwrap().clone() + } + } + + #[async_trait] + impl DockerBridge for RecordingBridge { + async fn ping(&self) -> Result<(), DockerError> { + Ok(()) + } + + async fn list_networks( + &self, + _options: Option>, + ) -> Result, DockerError> { + Ok(self.inner.networks.lock().unwrap().clone()) + } + + async fn create_network( + &self, + options: CreateNetworkOptions, + ) -> Result<(), DockerError> { + self.inner + .created_networks + .lock() + .unwrap() + .push(options.name); + Ok(()) + } + + async fn inspect_container( + &self, + id: &str, + _options: Option, + ) -> Result { + self.inner + .inspect_map + .lock() + .unwrap() + .get(id) + .cloned() + .ok_or_else(|| DockerError::IOError { + err: std::io::Error::new(std::io::ErrorKind::NotFound, "missing inspect"), + }) + } + + async fn connect_network( + &self, + network: &str, + options: ConnectNetworkOptions, + ) -> Result<(), DockerError> { + self.inner + .connect_calls + .lock() + .unwrap() + .push((options.container, network.to_string())); + Ok(()) + } + + fn create_image_stream(&self, _options: Option>) -> ImageStream { + futures::stream::empty().boxed() + } + + async fn create_volume( + &self, + _options: CreateVolumeOptions, + ) -> Result<(), DockerError> { + Ok(()) + } + + async fn create_container( + &self, + _options: Option>, + _config: Config, + ) -> Result { + panic!("not used in tests") + } + + async fn start_container( + &self, + _id: &str, + _options: Option>, + ) -> Result<(), DockerError> { + panic!("not used in tests") + } + + async fn stop_container( + &self, + _id: &str, + _options: Option, + ) -> Result<(), DockerError> { + panic!("not used in tests") + } + + async fn remove_container( + &self, + id: &str, + _options: Option, + ) -> Result<(), DockerError> { + self.inner.removed.lock().unwrap().push(id.to_string()); + Ok(()) + } + + async fn list_containers( + &self, + _options: Option>, + ) -> Result, DockerError> { + Ok(self.inner.containers.lock().unwrap().clone()) + } + + fn logs_stream(&self, _id: &str, _options: LogsOptions) -> LogStream { + futures::stream::empty().boxed() + } + } + + #[tokio::test] + async fn test_ensure_network_creates_when_missing() { + let bridge = RecordingBridge::default(); + let client = DockerClient::with_bridge(bridge.clone(), "platform-network"); + client.ensure_network().await.unwrap(); + assert_eq!( + bridge.created_networks(), + vec!["platform-network".to_string()] + ); + } + + #[tokio::test] + async fn test_ensure_network_skips_existing() { + let bridge = RecordingBridge::with_networks(&["platform-network"]); + let client = DockerClient::with_bridge(bridge.clone(), "platform-network"); + client.ensure_network().await.unwrap(); + assert!(bridge.created_networks().is_empty()); + } + + #[tokio::test] + #[serial] + async fn test_connect_self_to_network_only_when_needed() { + let bridge = RecordingBridge::default(); + let container_id = "aaaaaaaaaaaa"; + std::env::set_var("HOSTNAME", container_id); + bridge.set_inspect_networks(container_id, &[]); + let client = DockerClient::with_bridge(bridge.clone(), "platform-network"); + client.connect_self_to_network().await.unwrap(); + assert_eq!( + bridge.connect_calls(), + vec![(container_id.to_string(), "platform-network".to_string())] + ); + + let bridge2 = RecordingBridge::default(); + let container_two = "bbbbbbbbbbbb"; + std::env::set_var("HOSTNAME", container_two); + bridge2.set_inspect_networks(container_two, &["platform-network"]); + let client2 = DockerClient::with_bridge(bridge2.clone(), "platform-network"); + client2.connect_self_to_network().await.unwrap(); + assert!(bridge2.connect_calls().is_empty()); + std::env::remove_var("HOSTNAME"); + } + + fn make_container_summary(id: &str, name: &str, created: i64) -> ContainerSummary { + ContainerSummary { + id: Some(id.to_string()), + names: Some(vec![format!("/{name}")]), + created: Some(created), + ..Default::default() + } + } + + #[tokio::test] + async fn test_cleanup_stale_containers_filters_entries() { + let bridge = RecordingBridge::default(); + let now = chrono::Utc::now().timestamp(); + bridge.set_containers(vec![ + make_container_summary("old", "term-challenge-old", now - 10_000), + make_container_summary("exclude", "platform-helper", now - 10_000), + make_container_summary("young", "term-challenge-young", now - 100), + ]); + let client = DockerClient::with_bridge(bridge.clone(), "platform-network"); + + let result = client + .cleanup_stale_containers("term-challenge-", 120, &["platform-"]) + .await + .unwrap(); + assert_eq!(result.total_found, 1); + assert_eq!(result.removed, 1); + assert_eq!(bridge.removed_containers(), vec!["old".to_string()]); + } +} + /// Result of container cleanup operation -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct CleanupResult { pub total_found: usize, pub removed: usize, @@ -956,13 +1552,15 @@ impl CleanupResult { } #[cfg(test)] -mod tests { - use super::*; +mod cleanup_tests { + use super::CleanupResult; - #[tokio::test] - #[ignore = "requires Docker"] - async fn test_docker_connect() { - let client = DockerClient::connect().await; - assert!(client.is_ok()); + #[test] + fn test_cleanup_result_success_flag() { + let mut result = CleanupResult::default(); + assert!(result.success()); + + result.errors.push("boom".into()); + assert!(!result.success()); } } diff --git a/crates/challenge-orchestrator/src/evaluator.rs b/crates/challenge-orchestrator/src/evaluator.rs index 24af7235a..85647ffc9 100644 --- a/crates/challenge-orchestrator/src/evaluator.rs +++ b/crates/challenge-orchestrator/src/evaluator.rs @@ -1,9 +1,11 @@ //! Challenge evaluator - generic proxy for routing requests to challenge containers //! -//! IMPORTANT: This evaluator is challenge-agnostic. Each challenge defines its own -//! request/response format. The evaluator simply proxies JSON payloads. +//! The evaluator keeps HTTP plumbing separate from challenge logic. It simply +//! forwards JSON payloads to the configured container endpoint, enforces +//! timeouts, and surfaces useful errors back to the validator. //! -//! For term-challenge specific formats, see term-challenge-repo/src/server.rs +//! For challenge-specific schemas, see each challenge repository (for example, +//! `term-challenge-repo/src/server.rs`). use crate::{ChallengeInstance, ContainerStatus}; use parking_lot::RwLock; @@ -14,7 +16,8 @@ use std::sync::Arc; use std::time::Duration; use tracing::{debug, info, warn}; -/// Generic evaluator for routing requests to challenge containers +/// Generic evaluator for routing requests to challenge containers with baked-in +/// HTTP client configuration (timeouts, retries handled upstream). pub struct ChallengeEvaluator { challenges: Arc>>, client: reqwest::Client, @@ -270,6 +273,25 @@ pub enum EvaluatorError { #[cfg(test)] mod tests { use super::*; + use parking_lot::RwLock; + use platform_core::ChallengeId; + use std::collections::HashMap; + use std::sync::Arc; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + use tokio::task::JoinHandle; + use tokio_test::block_on; + + fn sample_instance(status: ContainerStatus) -> ChallengeInstance { + ChallengeInstance { + challenge_id: ChallengeId::new(), + container_id: "cid".into(), + image: "ghcr.io/platformnetwork/example:latest".into(), + endpoint: "http://127.0.0.1:9000".into(), + started_at: chrono::Utc::now(), + status, + } + } #[test] fn test_challenge_info_deserialize() { @@ -283,4 +305,337 @@ mod tests { assert_eq!(info.name, "term-challenge"); assert_eq!(info.mechanism_id, 0); // default } + + #[test] + fn test_evaluate_generic_requires_running_status() { + let challenges = Arc::new(RwLock::new(HashMap::new())); + let instance = sample_instance(ContainerStatus::Starting); + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance.clone()); + + let evaluator = ChallengeEvaluator::new(challenges); + let err = block_on(evaluator.evaluate_generic(challenge_id, serde_json::json!({}), None)) + .expect_err("should fail when not running"); + + match err { + EvaluatorError::ChallengeNotReady(id) => assert_eq!(id, challenge_id), + other => panic!("unexpected error: {:?}", other), + } + } + + #[test] + fn test_proxy_request_missing_challenge() { + let evaluator = ChallengeEvaluator::new(Arc::new(RwLock::new(HashMap::new()))); + let challenge_id = ChallengeId::new(); + let err = block_on(evaluator.proxy_request( + challenge_id, + "status", + reqwest::Method::GET, + None, + None, + )) + .expect_err("missing challenge should error"); + + match err { + EvaluatorError::ChallengeNotFound(id) => assert_eq!(id, challenge_id), + other => panic!("unexpected error: {:?}", other), + } + } + + #[tokio::test] + async fn test_proxy_request_requires_running_status() { + let challenges = Arc::new(RwLock::new(HashMap::new())); + let instance = sample_instance(ContainerStatus::Starting); + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance); + + let evaluator = ChallengeEvaluator::new(challenges); + let err = evaluator + .proxy_request(challenge_id, "health", reqwest::Method::GET, None, None) + .await + .expect_err("non-running challenge should be rejected"); + + match err { + EvaluatorError::ChallengeNotReady(id) => assert_eq!(id, challenge_id), + other => panic!("unexpected error: {:?}", other), + } + } + + #[test] + fn test_list_challenges_returns_current_instances() { + let challenges = Arc::new(RwLock::new(HashMap::new())); + let instance_a = sample_instance(ContainerStatus::Running); + let instance_b = sample_instance(ContainerStatus::Unhealthy); + let id_a = instance_a.challenge_id; + let id_b = instance_b.challenge_id; + challenges.write().insert(id_a, instance_a.clone()); + challenges.write().insert(id_b, instance_b.clone()); + + let evaluator = ChallengeEvaluator::new(challenges); + let list = evaluator.list_challenges(); + assert_eq!(list.len(), 2); + + let status_map: std::collections::HashMap = list + .into_iter() + .map(|entry| (entry.challenge_id, entry.status)) + .collect(); + + assert_eq!(status_map.get(&id_a), Some(&ContainerStatus::Running)); + assert_eq!(status_map.get(&id_b), Some(&ContainerStatus::Unhealthy)); + } + + #[tokio::test] + async fn test_evaluate_generic_succeeds_with_ok_response() { + let (addr, handle) = + spawn_static_http_server("200 OK", r#"{"value": 42}"#, "application/json").await; + let endpoint = format!("http://{}", addr); + let (evaluator, challenge_id) = evaluator_with_instance(endpoint, ContainerStatus::Running); + + let response = evaluator + .evaluate_generic(challenge_id, serde_json::json!({"input": 1}), Some(5)) + .await + .expect("evaluation succeeds"); + + assert_eq!(response["value"], 42); + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_evaluate_generic_reports_challenge_error() { + let (addr, handle) = + spawn_static_http_server("500 Internal Server Error", "boom", "text/plain").await; + let endpoint = format!("http://{}", addr); + let (evaluator, challenge_id) = evaluator_with_instance(endpoint, ContainerStatus::Running); + + let err = evaluator + .evaluate_generic(challenge_id, serde_json::json!({}), Some(5)) + .await + .expect_err("should propagate challenge error"); + + match err { + EvaluatorError::ChallengeError { status, message } => { + assert_eq!(status, 500); + assert_eq!(message, "boom"); + } + other => panic!("unexpected error: {:?}", other), + } + + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_evaluate_generic_reports_parse_error() { + let (addr, handle) = spawn_static_http_server("200 OK", "not json", "text/plain").await; + let endpoint = format!("http://{}", addr); + let (evaluator, challenge_id) = evaluator_with_instance(endpoint, ContainerStatus::Running); + + let err = evaluator + .evaluate_generic(challenge_id, serde_json::json!({}), Some(5)) + .await + .expect_err("invalid JSON should error"); + + assert!(matches!(err, EvaluatorError::ParseError(_))); + + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_evaluate_generic_reports_network_error() { + let (addr, handle) = spawn_drop_http_server().await; + let endpoint = format!("http://{}", addr); + let (evaluator, challenge_id) = evaluator_with_instance(endpoint, ContainerStatus::Running); + + let err = evaluator + .evaluate_generic(challenge_id, serde_json::json!({}), Some(1)) + .await + .expect_err("network failure should bubble up"); + + assert!(matches!(err, EvaluatorError::NetworkError(_))); + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_proxy_request_returns_payload() { + let (addr, handle) = + spawn_static_http_server("200 OK", r#"{"ok":true}"#, "application/json").await; + let endpoint = format!("http://{}", addr); + let (evaluator, challenge_id) = evaluator_with_instance(endpoint, ContainerStatus::Running); + + let response = evaluator + .proxy_request( + challenge_id, + "custom/path", + reqwest::Method::POST, + Some(serde_json::json!({"payload": true})), + Some(5), + ) + .await + .expect("proxy request succeeds"); + + assert_eq!(response["ok"], true); + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_proxy_request_reports_challenge_error() { + let (addr, handle) = + spawn_static_http_server("503 Service Unavailable", "oops", "text/plain").await; + let endpoint = format!("http://{}", addr); + let (evaluator, challenge_id) = evaluator_with_instance(endpoint, ContainerStatus::Running); + + let err = evaluator + .proxy_request(challenge_id, "custom", reqwest::Method::GET, None, Some(5)) + .await + .expect_err("should surface challenge error"); + + match err { + EvaluatorError::ChallengeError { status, message } => { + assert_eq!(status, 503); + assert_eq!(message, "oops"); + } + other => panic!("unexpected error: {:?}", other), + } + + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_get_info_fetches_metadata() { + let body = r#"{"name":"demo","version":"0.1.0","mechanism_id":7}"#; + let (addr, handle) = spawn_static_http_server("200 OK", body, "application/json").await; + let endpoint = format!("http://{}", addr); + let (evaluator, challenge_id) = evaluator_with_instance(endpoint, ContainerStatus::Running); + + let info = evaluator + .get_info(challenge_id) + .await + .expect("info should deserialize"); + + assert_eq!(info.name, "demo"); + assert_eq!(info.version, "0.1.0"); + assert_eq!(info.mechanism_id, 7); + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_get_info_reports_error_status() { + let (addr, handle) = + spawn_static_http_server("404 Not Found", "missing", "text/plain").await; + let endpoint = format!("http://{}", addr); + let (evaluator, challenge_id) = evaluator_with_instance(endpoint, ContainerStatus::Running); + + let err = evaluator + .get_info(challenge_id) + .await + .expect_err("should return challenge error for non-200 info"); + + match err { + EvaluatorError::ChallengeError { status, message } => { + assert_eq!(status, 404); + assert_eq!(message, "Failed to get challenge info"); + } + other => panic!("unexpected error: {:?}", other), + } + + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_check_health_reflects_status_code() { + let (addr_ok, handle_ok) = + spawn_static_http_server("200 OK", "{}", "application/json").await; + let (evaluator, ok_id) = + evaluator_with_instance(format!("http://{}", addr_ok), ContainerStatus::Running); + + assert!(evaluator + .check_health(ok_id) + .await + .expect("health request succeeds")); + handle_ok.await.expect("server finished"); + + let (addr_err, handle_err) = + spawn_static_http_server("503 Service Unavailable", "oops", "text/plain").await; + let (evaluator, fail_id) = + evaluator_with_instance(format!("http://{}", addr_err), ContainerStatus::Running); + + assert!(!evaluator + .check_health(fail_id) + .await + .expect("health request succeeds")); + handle_err.await.expect("server finished"); + } + + #[tokio::test] + async fn test_check_health_handles_request_failure() { + let (addr, handle) = spawn_drop_http_server().await; + let (evaluator, challenge_id) = + evaluator_with_instance(format!("http://{}", addr), ContainerStatus::Running); + + let result = evaluator + .check_health(challenge_id) + .await + .expect("network errors should be converted to false"); + + assert!(!result); + handle.await.expect("server finished"); + } + + fn evaluator_with_instance( + endpoint: String, + status: ContainerStatus, + ) -> (ChallengeEvaluator, ChallengeId) { + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut instance = sample_instance(status); + instance.endpoint = endpoint; + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance); + (ChallengeEvaluator::new(challenges), challenge_id) + } + + async fn spawn_static_http_server( + status_line: &str, + body: &str, + content_type: &str, + ) -> (std::net::SocketAddr, JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind local server"); + let addr = listener.local_addr().expect("read addr"); + let body = body.to_string(); + let content_type = content_type.to_string(); + let status_line = status_line.to_string(); + + let handle = tokio::spawn(async move { + if let Ok((mut socket, _)) = listener.accept().await { + let mut buf = vec![0u8; 1024]; + let _ = socket.read(&mut buf).await; + let response = format!( + "HTTP/1.1 {status}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\n\r\n{}", + body.len(), body, + status = status_line, + ); + let _ = socket.write_all(response.as_bytes()).await; + let _ = socket.shutdown().await; + } + }); + + (addr, handle) + } + + async fn spawn_drop_http_server() -> (std::net::SocketAddr, JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind local server"); + let addr = listener.local_addr().expect("read addr"); + + let handle = tokio::spawn(async move { + if let Ok((mut socket, _)) = listener.accept().await { + let mut buf = vec![0u8; 1024]; + let _ = socket.read(&mut buf).await; + // Drop connection without responding to trigger client-side network error. + } + }); + + (addr, handle) + } } diff --git a/crates/challenge-orchestrator/src/health.rs b/crates/challenge-orchestrator/src/health.rs index a7c30e996..d289089a3 100644 --- a/crates/challenge-orchestrator/src/health.rs +++ b/crates/challenge-orchestrator/src/health.rs @@ -200,6 +200,24 @@ impl HealthSummary { #[cfg(test)] mod tests { use super::*; + use parking_lot::RwLock; + use platform_core::ChallengeId; + use std::collections::HashMap; + use std::sync::Arc; + use std::time::{Duration, Instant}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + fn sample_instance(status: ContainerStatus) -> ChallengeInstance { + ChallengeInstance { + challenge_id: ChallengeId::new(), + container_id: "cid".into(), + image: "ghcr.io/platformnetwork/example:latest".into(), + endpoint: "http://127.0.0.1:9000".into(), + started_at: chrono::Utc::now(), + status, + } + } #[test] fn test_health_summary() { @@ -228,4 +246,324 @@ mod tests { assert!(summary.all_healthy()); assert_eq!(summary.percentage_healthy(), 100.0); } + + #[test] + fn test_percentage_healthy_handles_zero_total() { + let summary = HealthSummary { + total: 0, + running: 0, + unhealthy: 0, + starting: 0, + stopped: 0, + }; + + assert_eq!(summary.percentage_healthy(), 100.0); + } + + #[test] + fn test_get_unhealthy_lists_ids() { + let challenges = Arc::new(RwLock::new(HashMap::new())); + let healthy_instance = sample_instance(ContainerStatus::Running); + let healthy_id = healthy_instance.challenge_id; + let unhealthy_instance = sample_instance(ContainerStatus::Unhealthy); + let unhealthy_id = unhealthy_instance.challenge_id; + + { + let mut guard = challenges.write(); + guard.insert(healthy_id, healthy_instance.clone()); + guard.insert(unhealthy_id, unhealthy_instance.clone()); + } + + let monitor = HealthMonitor::new(challenges, Duration::from_secs(5)); + let ids = monitor.get_unhealthy(); + + assert_eq!(ids.len(), 1); + assert_eq!(ids[0], unhealthy_id); + } + + #[test] + fn test_health_monitor_summary_counts_statuses() { + let challenges = Arc::new(RwLock::new(HashMap::new())); + { + let mut guard = challenges.write(); + guard.insert( + ChallengeId::new(), + sample_instance(ContainerStatus::Running), + ); + guard.insert( + ChallengeId::new(), + sample_instance(ContainerStatus::Unhealthy), + ); + guard.insert( + ChallengeId::new(), + sample_instance(ContainerStatus::Starting), + ); + } + + let monitor = HealthMonitor::new(challenges, Duration::from_secs(5)); + let summary = monitor.summary(); + + assert_eq!(summary.total, 3); + assert_eq!(summary.running, 1); + assert_eq!(summary.unhealthy, 1); + assert_eq!(summary.starting, 1); + assert_eq!(summary.stopped, 0); + } + + async fn spawn_health_server( + status_line: &str, + body: &str, + ) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind local server"); + let addr = listener.local_addr().expect("read addr"); + let body = body.to_string(); + let response = format!( + "HTTP/1.1 {status_line}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), body + ); + + let handle = tokio::spawn(async move { + if let Ok((mut socket, _)) = listener.accept().await { + let mut buf = [0u8; 1024]; + let _ = socket.read(&mut buf).await; + let _ = socket.write_all(response.as_bytes()).await; + } + }); + + (addr, handle) + } + + async fn spawn_repeating_health_server( + status_line: &str, + body: &str, + ) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind repeating server"); + let addr = listener.local_addr().expect("read addr"); + let body = body.to_string(); + let response = Arc::new(format!( + "HTTP/1.1 {status_line}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), body + )); + + let handle = tokio::spawn(async move { + loop { + let (mut socket, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let resp = response.clone(); + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + let _ = socket.read(&mut buf).await; + let _ = socket.write_all(resp.as_bytes()).await; + }); + } + }); + + (addr, handle) + } + + async fn spawn_closing_health_server() -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind closing server"); + let addr = listener.local_addr().expect("read addr"); + let handle = tokio::spawn(async move { + loop { + let (socket, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + drop(socket); + } + }); + (addr, handle) + } + + #[tokio::test] + async fn test_health_monitor_check_sets_running_on_success() { + let (addr, handle) = spawn_health_server("200 OK", r#"{"status":"ok"}"#).await; + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut instance = sample_instance(ContainerStatus::Starting); + instance.endpoint = format!("http://{}", addr); + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance); + + let monitor = HealthMonitor::new(challenges.clone(), Duration::from_secs(5)); + let status = monitor + .check(&challenge_id) + .await + .expect("status should be returned"); + + assert_eq!(status, ContainerStatus::Running); + assert_eq!( + challenges + .read() + .get(&challenge_id) + .expect("challenge present") + .status, + ContainerStatus::Running + ); + + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_health_monitor_check_marks_unhealthy_on_failure() { + let (addr, handle) = + spawn_health_server("500 Internal Server Error", r#"{"status":"error"}"#).await; + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut instance = sample_instance(ContainerStatus::Running); + instance.endpoint = format!("http://{}", addr); + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance); + + let monitor = HealthMonitor::new(challenges.clone(), Duration::from_secs(5)); + let status = monitor + .check(&challenge_id) + .await + .expect("status should be returned"); + + assert_eq!(status, ContainerStatus::Unhealthy); + assert_eq!( + challenges + .read() + .get(&challenge_id) + .expect("challenge present") + .status, + ContainerStatus::Unhealthy + ); + + handle.await.expect("server finished"); + } + + #[tokio::test] + async fn test_health_monitor_start_updates_status() { + let (addr, handle) = spawn_repeating_health_server("200 OK", r#"{"status":"ok"}"#).await; + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut instance = sample_instance(ContainerStatus::Starting); + instance.endpoint = format!("http://{}", addr); + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance); + + let monitor = HealthMonitor::new(challenges.clone(), Duration::from_millis(10)); + monitor.start().await.expect("monitor starts"); + + let deadline = Instant::now() + Duration::from_millis(500); + loop { + if challenges + .read() + .get(&challenge_id) + .map(|inst| inst.status == ContainerStatus::Running) + .unwrap_or(false) + { + break; + } + + if Instant::now() > deadline { + panic!("status never updated to running"); + } + + tokio::time::sleep(Duration::from_millis(20)).await; + } + + handle.abort(); + } + + #[tokio::test] + async fn test_health_monitor_start_marks_unhealthy_on_failed_response() { + let (addr, handle) = + spawn_repeating_health_server("500 Internal Server Error", r#"{"status":"error"}"#) + .await; + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut instance = sample_instance(ContainerStatus::Running); + instance.endpoint = format!("http://{}", addr); + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance); + + let monitor = HealthMonitor::new(challenges.clone(), Duration::from_millis(10)); + monitor.start().await.expect("monitor starts"); + + let deadline = Instant::now() + Duration::from_millis(500); + loop { + if challenges + .read() + .get(&challenge_id) + .map(|inst| inst.status == ContainerStatus::Unhealthy) + .unwrap_or(false) + { + break; + } + + if Instant::now() > deadline { + panic!("status never updated to unhealthy"); + } + + tokio::time::sleep(Duration::from_millis(20)).await; + } + + handle.abort(); + } + + #[tokio::test] + async fn test_health_monitor_start_handles_request_error() { + let (addr, handle) = spawn_closing_health_server().await; + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut instance = sample_instance(ContainerStatus::Running); + instance.endpoint = format!("http://{}", addr); + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance); + + let monitor = HealthMonitor::new(challenges.clone(), Duration::from_millis(10)); + monitor.start().await.expect("monitor starts"); + + let deadline = Instant::now() + Duration::from_millis(500); + loop { + if challenges + .read() + .get(&challenge_id) + .map(|inst| inst.status == ContainerStatus::Unhealthy) + .unwrap_or(false) + { + break; + } + + if Instant::now() > deadline { + panic!("status never updated after request error"); + } + + tokio::time::sleep(Duration::from_millis(20)).await; + } + + handle.abort(); + } + + #[tokio::test] + async fn test_health_monitor_check_treats_parse_error_as_healthy() { + let (addr, handle) = spawn_health_server("200 OK", "not-json").await; + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut instance = sample_instance(ContainerStatus::Starting); + instance.endpoint = format!("http://{}", addr); + let challenge_id = instance.challenge_id; + challenges.write().insert(challenge_id, instance); + + let monitor = HealthMonitor::new(challenges.clone(), Duration::from_secs(5)); + let status = monitor.check(&challenge_id).await.expect("status returned"); + + assert_eq!(status, ContainerStatus::Running); + assert_eq!( + challenges + .read() + .get(&challenge_id) + .expect("challenge present") + .status, + ContainerStatus::Running + ); + + handle.await.expect("server finished"); + } } diff --git a/crates/challenge-orchestrator/src/lib.rs b/crates/challenge-orchestrator/src/lib.rs index 466ba5cae..6140a7254 100644 --- a/crates/challenge-orchestrator/src/lib.rs +++ b/crates/challenge-orchestrator/src/lib.rs @@ -1,15 +1,22 @@ //! Challenge Orchestrator //! -//! Manages Docker containers for challenges. Provides: -//! - Container lifecycle (start, stop, update) -//! - Health monitoring -//! - Evaluation routing -//! - Hot-swap without core restart +//! Provides a high-level API for managing the full lifecycle of challenge +//! containers. The crate wires together networking bootstrap, backend +//! selection, container health monitoring, and the HTTP evaluator used by the +//! validator node. //! -//! ## Backend Selection (Secure by Default) +//! ### Responsibilities +//! - Detect the correct container backend (secure broker vs. direct Docker) +//! - Keep challenge containers on the `platform-network` with automatic +//! self-attachment for the validator container +//! - Track every running challenge and expose health + evaluation helpers +//! - Refresh or hot-swap containers without bouncing the validator //! -//! The orchestrator uses the **secure broker by default** in production. -//! Direct Docker is ONLY used when explicitly in development mode. +//! ### Backend Selection (Secure by Default) +//! +//! The orchestrator always prefers the secure broker. Direct Docker is only +//! selected when `DEVELOPMENT_MODE=true`, which explicitly opts into relaxed +//! security for local workflows. //! //! Priority order: //! 1. `DEVELOPMENT_MODE=true` -> Direct Docker (local dev only) @@ -30,20 +37,20 @@ pub use backend::{ SecureBackend, DEFAULT_BROKER_SOCKET, }; pub use config::*; -pub use docker::{CleanupResult, DockerClient}; +pub use docker::{ChallengeDocker, CleanupResult, DockerClient}; pub use evaluator::*; pub use health::*; pub use lifecycle::*; - use parking_lot::RwLock; use platform_core::ChallengeId; use std::collections::HashMap; use std::sync::Arc; -/// Main orchestrator managing all challenge containers +/// High-level façade that keeps container state, evaluator access, and health +/// monitoring in sync for every registered challenge. #[allow(dead_code)] pub struct ChallengeOrchestrator { - docker: DockerClient, + docker: Arc, challenges: Arc>>, health_monitor: HealthMonitor, config: OrchestratorConfig, @@ -53,11 +60,27 @@ pub struct ChallengeOrchestrator { pub const PLATFORM_NETWORK: &str = "platform-network"; impl ChallengeOrchestrator { + /// Create a new orchestrator by auto-detecting the Docker runtime inside + /// the validator container and ensuring networking prerequisites exist. pub async fn new(config: OrchestratorConfig) -> anyhow::Result { + #[cfg(test)] + if let Some(docker) = Self::take_test_docker_client() { + return Self::bootstrap_with_docker(docker, config).await; + } + // Auto-detect the network from the validator container // This ensures challenge containers are on the same network as the validator let docker = DockerClient::connect_auto_detect().await?; + Self::bootstrap_with_docker(docker, config).await + } + + /// Reusable constructor path shared between production and tests once a + /// concrete Docker client is available. + async fn bootstrap_with_docker( + docker: DockerClient, + config: OrchestratorConfig, + ) -> anyhow::Result { // Ensure the detected network exists (creates it if running outside Docker) docker.ensure_network().await?; @@ -67,6 +90,35 @@ impl ChallengeOrchestrator { tracing::warn!("Could not connect validator to platform network: {}", e); } + Self::with_docker(docker, config).await + } + + #[cfg(test)] + fn test_docker_client_slot() -> &'static std::sync::Mutex> { + use std::sync::{Mutex, OnceLock}; + static SLOT: OnceLock>> = OnceLock::new(); + SLOT.get_or_init(|| Mutex::new(None)) + } + + #[cfg(test)] + fn take_test_docker_client() -> Option { + Self::test_docker_client_slot().lock().unwrap().take() + } + + #[cfg(test)] + pub(crate) fn set_test_docker_client(docker: DockerClient) { + Self::test_docker_client_slot() + .lock() + .unwrap() + .replace(docker); + } + + /// Build an orchestrator with a custom Docker implementation + pub async fn with_docker( + docker: impl ChallengeDocker + 'static, + config: OrchestratorConfig, + ) -> anyhow::Result { + let docker = Arc::new(docker); let challenges = Arc::new(RwLock::new(HashMap::new())); let health_monitor = HealthMonitor::new(challenges.clone(), config.health_check_interval); @@ -276,8 +328,8 @@ impl ChallengeOrchestrator { } /// Get the Docker client for direct operations - pub fn docker(&self) -> &DockerClient { - &self.docker + pub fn docker(&self) -> &dyn ChallengeDocker { + self.docker.as_ref() } } @@ -299,3 +351,655 @@ pub enum ContainerStatus { Unhealthy, Stopped, } + +#[cfg(test)] +mod tests { + use super::*; + use crate::docker::DockerBridge; + use async_trait::async_trait; + use bollard::container::{ + Config, CreateContainerOptions, InspectContainerOptions, ListContainersOptions, LogOutput, + LogsOptions, RemoveContainerOptions, StartContainerOptions, StopContainerOptions, + }; + use bollard::errors::Error as DockerError; + use bollard::image::CreateImageOptions; + use bollard::models::{ + ContainerCreateResponse, ContainerInspectResponse, ContainerSummary, CreateImageInfo, + EndpointSettings, Network, NetworkSettings, + }; + use bollard::network::{ConnectNetworkOptions, CreateNetworkOptions, ListNetworksOptions}; + use bollard::volume::CreateVolumeOptions; + use chrono::Utc; + use futures::{stream, Stream}; + use platform_core::ChallengeId; + use serial_test::serial; + use std::collections::HashMap; + use std::pin::Pin; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::{Arc, Mutex}; + + #[derive(Clone, Default)] + struct TestDocker { + inner: Arc, + } + + struct TestDockerInner { + operations: Mutex>, + cleanup_result: Mutex, + cleanup_calls: Mutex)>>, + next_container_id: AtomicUsize, + } + + impl Default for TestDockerInner { + fn default() -> Self { + Self { + operations: Mutex::new(Vec::new()), + cleanup_result: Mutex::new(CleanupResult::default()), + cleanup_calls: Mutex::new(Vec::new()), + next_container_id: AtomicUsize::new(1), + } + } + } + + impl TestDocker { + fn record(&self, entry: impl Into) { + self.inner.operations.lock().unwrap().push(entry.into()); + } + + fn operations(&self) -> Vec { + self.inner.operations.lock().unwrap().clone() + } + + fn set_cleanup_result(&self, result: CleanupResult) { + *self.inner.cleanup_result.lock().unwrap() = result; + } + + fn cleanup_calls(&self) -> Vec<(String, u64, Vec)> { + self.inner.cleanup_calls.lock().unwrap().clone() + } + + fn next_instance(&self, config: &ChallengeContainerConfig) -> ChallengeInstance { + let idx = self.inner.next_container_id.fetch_add(1, Ordering::SeqCst); + let id_str = config.challenge_id.to_string(); + ChallengeInstance { + challenge_id: config.challenge_id, + container_id: format!("container-{id_str}-{idx}"), + image: config.docker_image.clone(), + endpoint: format!("http://{id_str}:{idx}"), + started_at: Utc::now(), + status: ContainerStatus::Running, + } + } + } + + #[async_trait] + impl ChallengeDocker for TestDocker { + async fn pull_image(&self, image: &str) -> anyhow::Result<()> { + self.record(format!("pull:{image}")); + Ok(()) + } + + async fn start_challenge( + &self, + config: &ChallengeContainerConfig, + ) -> anyhow::Result { + self.record(format!("start:{}", config.challenge_id)); + Ok(self.next_instance(config)) + } + + async fn stop_container(&self, container_id: &str) -> anyhow::Result<()> { + self.record(format!("stop:{container_id}")); + Ok(()) + } + + async fn remove_container(&self, container_id: &str) -> anyhow::Result<()> { + self.record(format!("remove:{container_id}")); + Ok(()) + } + + async fn is_container_running(&self, container_id: &str) -> anyhow::Result { + self.record(format!("is_running:{container_id}")); + Ok(true) + } + + async fn get_logs(&self, container_id: &str, tail: usize) -> anyhow::Result { + self.record(format!("logs:{container_id}:{tail}")); + Ok(format!("logs-{container_id}")) + } + + async fn list_challenge_containers(&self) -> anyhow::Result> { + self.record("list_containers".to_string()); + Ok(Vec::new()) + } + + async fn cleanup_stale_containers( + &self, + prefix: &str, + max_age_minutes: u64, + exclude_patterns: &[&str], + ) -> anyhow::Result { + self.record(format!("cleanup:{prefix}:{max_age_minutes}")); + self.inner.cleanup_calls.lock().unwrap().push(( + prefix.to_string(), + max_age_minutes, + exclude_patterns.iter().map(|s| s.to_string()).collect(), + )); + Ok(self.inner.cleanup_result.lock().unwrap().clone()) + } + } + + fn sample_config_with_id(challenge_id: ChallengeId, image: &str) -> ChallengeContainerConfig { + let id_str = challenge_id.to_string(); + ChallengeContainerConfig { + challenge_id, + name: format!("challenge-{id_str}"), + docker_image: image.to_string(), + mechanism_id: 0, + emission_weight: 1.0, + timeout_secs: 300, + cpu_cores: 1.0, + memory_mb: 512, + gpu_required: false, + } + } + + fn sample_config(image: &str) -> ChallengeContainerConfig { + sample_config_with_id(ChallengeId::new(), image) + } + + async fn orchestrator_with_mock(docker: TestDocker) -> ChallengeOrchestrator { + ChallengeOrchestrator::with_docker(docker, OrchestratorConfig::default()) + .await + .expect("build orchestrator") + } + + #[tokio::test] + async fn test_add_challenge_registers_instance() { + let docker = TestDocker::default(); + let orchestrator = orchestrator_with_mock(docker.clone()).await; + let config = sample_config("ghcr.io/platformnetwork/challenge:v1"); + let challenge_id = config.challenge_id; + + orchestrator + .add_challenge(config.clone()) + .await + .expect("add challenge"); + + let stored = orchestrator + .get_challenge(&challenge_id) + .expect("challenge stored"); + assert_eq!(stored.image, config.docker_image); + assert_eq!(orchestrator.list_challenges(), vec![challenge_id]); + + let ops = docker.operations(); + assert!(ops.contains(&format!("pull:{}", config.docker_image))); + assert!(ops.contains(&format!("start:{}", challenge_id))); + } + + #[tokio::test] + async fn test_update_challenge_restarts_with_new_image() { + let docker = TestDocker::default(); + let orchestrator = orchestrator_with_mock(docker.clone()).await; + let mut config = sample_config("ghcr.io/platformnetwork/challenge:v1"); + let challenge_id = config.challenge_id; + + orchestrator + .add_challenge(config.clone()) + .await + .expect("initial add"); + let initial_instance = orchestrator + .get_challenge(&challenge_id) + .expect("initial instance"); + + config.docker_image = "ghcr.io/platformnetwork/challenge:v2".into(); + orchestrator + .update_challenge(config.clone()) + .await + .expect("update succeeds"); + + let updated = orchestrator + .get_challenge(&challenge_id) + .expect("updated instance"); + assert_eq!(updated.image, config.docker_image); + assert_ne!(updated.container_id, initial_instance.container_id); + + let ops = docker.operations(); + assert!(ops + .iter() + .any(|op| op == &format!("stop:{}", initial_instance.container_id))); + assert!(ops + .iter() + .any(|op| op == &format!("pull:{}", config.docker_image))); + } + + #[tokio::test] + async fn test_remove_challenge_stops_and_removes_container() { + let docker = TestDocker::default(); + let orchestrator = orchestrator_with_mock(docker.clone()).await; + let config = sample_config("ghcr.io/platformnetwork/challenge:remove"); + let challenge_id = config.challenge_id; + + orchestrator + .add_challenge(config) + .await + .expect("added challenge"); + let container_id = orchestrator + .get_challenge(&challenge_id) + .unwrap() + .container_id; + + orchestrator + .remove_challenge(challenge_id) + .await + .expect("removed challenge"); + assert!(orchestrator.get_challenge(&challenge_id).is_none()); + + let ops = docker.operations(); + assert!(ops.contains(&format!("stop:{container_id}"))); + assert!(ops.contains(&format!("remove:{container_id}"))); + } + + #[tokio::test] + async fn test_refresh_challenge_repulls_image() { + let docker = TestDocker::default(); + let orchestrator = orchestrator_with_mock(docker.clone()).await; + let config = sample_config("ghcr.io/platformnetwork/challenge:refresh"); + let challenge_id = config.challenge_id; + + orchestrator + .add_challenge(config.clone()) + .await + .expect("added challenge"); + let initial = orchestrator + .get_challenge(&challenge_id) + .expect("initial instance"); + + orchestrator + .refresh_challenge(challenge_id) + .await + .expect("refresh succeeds"); + let refreshed = orchestrator + .get_challenge(&challenge_id) + .expect("refreshed instance"); + + assert_eq!(refreshed.image, initial.image); + assert_ne!(refreshed.container_id, initial.container_id); + + let ops = docker.operations(); + let pull_count = ops + .iter() + .filter(|op| *op == &format!("pull:{}", initial.image)) + .count(); + assert_eq!(pull_count, 2, "pull once for add, once for refresh"); + } + + #[tokio::test] + async fn test_sync_challenges_handles_all_paths() { + let docker = TestDocker::default(); + let orchestrator = orchestrator_with_mock(docker.clone()).await; + let update_config = sample_config("ghcr.io/platformnetwork/challenge:update-v1"); + let remove_config = sample_config("ghcr.io/platformnetwork/challenge:remove-v1"); + let update_id = update_config.challenge_id; + let remove_id = remove_config.challenge_id; + + orchestrator + .add_challenge(update_config.clone()) + .await + .expect("added update target"); + orchestrator + .add_challenge(remove_config.clone()) + .await + .expect("added removal target"); + + let remove_container_id = orchestrator.get_challenge(&remove_id).unwrap().container_id; + + let new_id = ChallengeId::new(); + let desired = vec![ + sample_config_with_id(update_id, "ghcr.io/platformnetwork/challenge:update-v2"), + sample_config_with_id(new_id, "ghcr.io/platformnetwork/challenge:new"), + ]; + + orchestrator + .sync_challenges(&desired) + .await + .expect("sync succeeds"); + + let ids = orchestrator.list_challenges(); + assert!(ids.contains(&update_id)); + assert!(ids.contains(&new_id)); + assert!(!ids.contains(&remove_id)); + + let ops = docker.operations(); + assert!(ops.contains(&format!("stop:{remove_container_id}"))); + assert!(ops.contains(&format!("remove:{remove_container_id}"))); + assert!(ops + .iter() + .any(|op| op == &"pull:ghcr.io/platformnetwork/challenge:update-v2".to_string())); + assert!(ops + .iter() + .any(|op| op == &"pull:ghcr.io/platformnetwork/challenge:new".to_string())); + } + + #[tokio::test] + async fn test_cleanup_stale_task_containers_propagates_result() { + let docker = TestDocker::default(); + docker.set_cleanup_result(CleanupResult { + total_found: 3, + removed: 2, + errors: vec!["dang".into()], + }); + let orchestrator = orchestrator_with_mock(docker.clone()).await; + + let result = orchestrator + .cleanup_stale_task_containers() + .await + .expect("cleanup ok"); + assert_eq!(result.total_found, 3); + assert_eq!(result.removed, 2); + assert_eq!(result.errors, vec!["dang".to_string()]); + + let calls = docker.cleanup_calls(); + assert_eq!(calls.len(), 1); + let (prefix, max_age, excludes) = &calls[0]; + assert_eq!(prefix, "term-challenge-"); + assert_eq!(*max_age, 120); + let expected: Vec = vec![ + "challenge-term-challenge".to_string(), + "platform-".to_string(), + ]; + assert_eq!(excludes, &expected); + } + + #[tokio::test] + async fn test_refresh_all_challenges_refreshes_each_container() { + let docker = TestDocker::default(); + let orchestrator = orchestrator_with_mock(docker.clone()).await; + let config_a = sample_config("ghcr.io/platformnetwork/challenge:refresh-a"); + let config_b = sample_config("ghcr.io/platformnetwork/challenge:refresh-b"); + let id_a = config_a.challenge_id; + let id_b = config_b.challenge_id; + + orchestrator + .add_challenge(config_a.clone()) + .await + .expect("added first challenge"); + orchestrator + .add_challenge(config_b.clone()) + .await + .expect("added second challenge"); + + let first_initial = orchestrator + .get_challenge(&id_a) + .expect("first challenge present") + .container_id; + let second_initial = orchestrator + .get_challenge(&id_b) + .expect("second challenge present") + .container_id; + + orchestrator + .refresh_all_challenges() + .await + .expect("refresh all succeeds"); + + let first_refreshed = orchestrator + .get_challenge(&id_a) + .expect("first challenge refreshed") + .container_id; + let second_refreshed = orchestrator + .get_challenge(&id_b) + .expect("second challenge refreshed") + .container_id; + + assert_ne!(first_initial, first_refreshed); + assert_ne!(second_initial, second_refreshed); + + let ops = docker.operations(); + assert!(ops.contains(&format!("stop:{first_initial}"))); + assert!(ops.contains(&format!("stop:{second_initial}"))); + } + + #[tokio::test] + async fn test_start_launches_health_monitor() { + let orchestrator = orchestrator_with_mock(TestDocker::default()).await; + orchestrator + .start() + .await + .expect("health monitor start succeeds"); + } + + #[tokio::test] + async fn test_evaluator_method_returns_shared_state() { + let docker = TestDocker::default(); + let orchestrator = orchestrator_with_mock(docker).await; + let config = sample_config("ghcr.io/platformnetwork/challenge:evaluator"); + let challenge_id = config.challenge_id; + + orchestrator + .add_challenge(config) + .await + .expect("challenge added"); + + let evaluator = orchestrator.evaluator(); + let ids: Vec<_> = evaluator + .list_challenges() + .into_iter() + .map(|status| status.challenge_id) + .collect(); + + assert_eq!(ids, vec![challenge_id]); + } + + #[tokio::test] + async fn test_docker_method_exposes_underlying_client() { + let docker = TestDocker::default(); + let orchestrator = orchestrator_with_mock(docker.clone()).await; + + orchestrator + .docker() + .list_challenge_containers() + .await + .expect("list call succeeds"); + + let ops = docker.operations(); + assert!(ops.contains(&"list_containers".to_string())); + } + + #[tokio::test] + #[serial] + async fn test_new_uses_injected_docker_client() { + let bridge = TestDockerBridge::default(); + let docker = DockerClient::with_bridge(bridge.clone(), PLATFORM_NETWORK); + ChallengeOrchestrator::set_test_docker_client(docker); + + let original_hostname = std::env::var("HOSTNAME").ok(); + std::env::set_var("HOSTNAME", "abcdef123456"); + + let orchestrator = ChallengeOrchestrator::new(OrchestratorConfig::default()) + .await + .expect("constructed orchestrator"); + assert_eq!( + bridge.created_networks(), + vec![PLATFORM_NETWORK.to_string()] + ); + assert!(bridge + .connected_networks() + .iter() + .any(|name| name == PLATFORM_NETWORK)); + + drop(orchestrator); + + if let Some(value) = original_hostname { + std::env::set_var("HOSTNAME", value); + } else { + std::env::remove_var("HOSTNAME"); + } + } + + #[derive(Clone, Default)] + struct TestDockerBridge { + inner: Arc, + } + + #[derive(Default)] + struct TestDockerBridgeInner { + available_networks: Mutex>, + created_networks: Mutex>, + connected_networks: Mutex>, + } + + impl TestDockerBridge { + fn created_networks(&self) -> Vec { + self.inner.created_networks.lock().unwrap().clone() + } + + fn connected_networks(&self) -> Vec { + self.inner.connected_networks.lock().unwrap().clone() + } + } + + #[async_trait] + impl DockerBridge for TestDockerBridge { + async fn ping(&self) -> Result<(), DockerError> { + Ok(()) + } + + async fn list_networks( + &self, + _options: Option>, + ) -> Result, DockerError> { + let networks = self.inner.available_networks.lock().unwrap().clone(); + Ok(networks + .into_iter() + .map(|name| Network { + name: Some(name), + ..Default::default() + }) + .collect()) + } + + async fn create_network( + &self, + options: CreateNetworkOptions, + ) -> Result<(), DockerError> { + self.inner + .created_networks + .lock() + .unwrap() + .push(options.name.clone()); + self.inner + .available_networks + .lock() + .unwrap() + .push(options.name); + Ok(()) + } + + async fn inspect_container( + &self, + _id: &str, + _options: Option, + ) -> Result { + let mut map = HashMap::new(); + for name in self + .inner + .connected_networks + .lock() + .unwrap() + .iter() + .cloned() + { + map.insert(name, EndpointSettings::default()); + } + Ok(ContainerInspectResponse { + network_settings: Some(NetworkSettings { + networks: Some(map), + ..Default::default() + }), + ..Default::default() + }) + } + + async fn connect_network( + &self, + network: &str, + _options: ConnectNetworkOptions, + ) -> Result<(), DockerError> { + let mut connected = self.inner.connected_networks.lock().unwrap(); + if !connected.iter().any(|name| name == network) { + connected.push(network.to_string()); + } + let mut available = self.inner.available_networks.lock().unwrap(); + if !available.iter().any(|name| name == network) { + available.push(network.to_string()); + } + Ok(()) + } + + fn create_image_stream( + &self, + _options: Option>, + ) -> Pin> + Send>> { + Box::pin(stream::empty::>()) + as Pin> + Send>> + } + + async fn create_volume( + &self, + _options: CreateVolumeOptions, + ) -> Result<(), DockerError> { + Ok(()) + } + + async fn create_container( + &self, + _options: Option>, + _config: Config, + ) -> Result { + Ok(ContainerCreateResponse { + id: "test-container".to_string(), + warnings: Vec::new(), + }) + } + + async fn start_container( + &self, + _id: &str, + _options: Option>, + ) -> Result<(), DockerError> { + Ok(()) + } + + async fn stop_container( + &self, + _id: &str, + _options: Option, + ) -> Result<(), DockerError> { + Ok(()) + } + + async fn remove_container( + &self, + _id: &str, + _options: Option, + ) -> Result<(), DockerError> { + Ok(()) + } + + async fn list_containers( + &self, + _options: Option>, + ) -> Result, DockerError> { + Ok(Vec::new()) + } + + fn logs_stream( + &self, + _id: &str, + _options: LogsOptions, + ) -> Pin> + Send>> { + Box::pin(stream::empty::>()) + as Pin> + Send>> + } + } +} diff --git a/crates/challenge-orchestrator/src/lifecycle.rs b/crates/challenge-orchestrator/src/lifecycle.rs index d54704f4e..44f3c7edb 100644 --- a/crates/challenge-orchestrator/src/lifecycle.rs +++ b/crates/challenge-orchestrator/src/lifecycle.rs @@ -1,26 +1,34 @@ //! Container lifecycle management +//! +//! Handles add/update/remove flows for challenge containers while keeping the +//! in-memory config/state stores consistent. The lifecycle manager is used by +//! the orchestrator as the primitive for rolling refreshes, unhealthy restarts, +//! and declarative sync operations. -use crate::{ChallengeContainerConfig, ChallengeInstance, ContainerStatus, DockerClient}; +#[cfg(test)] +use crate::CleanupResult; +use crate::{ChallengeContainerConfig, ChallengeDocker, ChallengeInstance, ContainerStatus}; use parking_lot::RwLock; use platform_core::ChallengeId; use std::collections::HashMap; use std::sync::Arc; use tracing::{error, info}; -/// Manages the lifecycle of challenge containers +/// Manages the lifecycle of challenge containers, retaining both the live +/// container handles and the configs needed to recreate them during restarts. pub struct LifecycleManager { - docker: DockerClient, + docker: Box, challenges: Arc>>, configs: Arc>>, } impl LifecycleManager { pub fn new( - docker: DockerClient, + docker: impl ChallengeDocker + 'static, challenges: Arc>>, ) -> Self { Self { - docker, + docker: Box::new(docker), challenges, configs: Arc::new(RwLock::new(HashMap::new())), } @@ -222,6 +230,10 @@ impl SyncResult { #[cfg(test)] mod tests { use super::*; + use async_trait::async_trait; + use chrono::Utc; + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; #[test] fn test_sync_result_default() { @@ -249,4 +261,313 @@ mod tests { assert!(!result.is_success()); } + + #[tokio::test] + async fn test_restart_unhealthy_restarts_only_unhealthy() { + let mock = MockDocker::default(); + let mut manager = + LifecycleManager::new(mock.clone(), Arc::new(RwLock::new(HashMap::new()))); + + let unhealthy_id = ChallengeId::new(); + let healthy_id = ChallengeId::new(); + let unhealthy_container_id = "container-unhealthy"; + let healthy_container_id = "container-healthy"; + + manager.configs.write().insert( + unhealthy_id, + sample_config(unhealthy_id, "ghcr.io/org/unhealthy:1"), + ); + manager.configs.write().insert( + healthy_id, + sample_config(healthy_id, "ghcr.io/org/healthy:1"), + ); + + manager.challenges.write().insert( + unhealthy_id, + sample_instance( + unhealthy_id, + unhealthy_container_id, + "ghcr.io/org/unhealthy:1", + ContainerStatus::Unhealthy, + ), + ); + manager.challenges.write().insert( + healthy_id, + sample_instance( + healthy_id, + healthy_container_id, + "ghcr.io/org/healthy:1", + ContainerStatus::Running, + ), + ); + + let results = manager.restart_unhealthy().await; + + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, unhealthy_id); + assert!(results[0].1.is_ok()); + + let ops = mock.operations(); + assert!(ops + .iter() + .any(|op| op == &format!("stop:{unhealthy_container_id}"))); + assert!(ops + .iter() + .any(|op| op == &format!("remove:{unhealthy_container_id}"))); + assert!(ops + .iter() + .any(|op| op == &format!("start:{}", unhealthy_id.to_string()))); + assert!(!ops + .iter() + .any(|op| op == &format!("stop:{healthy_container_id}"))); + } + + #[tokio::test] + async fn test_sync_handles_add_update_remove() { + let mock = MockDocker::default(); + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut manager = LifecycleManager::new(mock.clone(), challenges); + + let update_id = ChallengeId::new(); + let remove_id = ChallengeId::new(); + let new_id = ChallengeId::new(); + + manager + .configs + .write() + .insert(update_id, sample_config(update_id, "ghcr.io/org/update:v1")); + manager + .configs + .write() + .insert(remove_id, sample_config(remove_id, "ghcr.io/org/remove:v1")); + + manager.challenges.write().insert( + update_id, + sample_instance( + update_id, + "container-update-old", + "ghcr.io/org/update:v1", + ContainerStatus::Running, + ), + ); + manager.challenges.write().insert( + remove_id, + sample_instance( + remove_id, + "container-remove-old", + "ghcr.io/org/remove:v1", + ContainerStatus::Running, + ), + ); + + let result = manager + .sync(vec![ + sample_config(update_id, "ghcr.io/org/update:v2"), + sample_config(new_id, "ghcr.io/org/new:v1"), + ]) + .await + .expect("sync succeeds"); + + assert_eq!(result.added, vec![new_id]); + assert_eq!(result.updated, vec![update_id]); + assert_eq!(result.removed, vec![remove_id]); + assert!(result.errors.is_empty()); + assert!(result.unchanged.is_empty()); + + let challenges = manager.challenges.read(); + assert!(challenges.contains_key(&update_id)); + assert!(challenges.contains_key(&new_id)); + assert!(!challenges.contains_key(&remove_id)); + drop(challenges); + + let ops = mock.operations(); + assert!(ops.iter().any(|op| op == "pull:ghcr.io/org/update:v2")); + assert!(ops.iter().any(|op| op == "pull:ghcr.io/org/new:v1")); + assert!(ops + .iter() + .any(|op| op == &format!("start:{}", update_id.to_string()))); + assert!(ops + .iter() + .any(|op| op == &format!("start:{}", new_id.to_string()))); + assert!(ops.iter().any(|op| op == "stop:container-update-old")); + assert!(ops.iter().any(|op| op == "remove:container-update-old")); + assert!(ops.iter().any(|op| op == "stop:container-remove-old")); + assert!(ops.iter().any(|op| op == "remove:container-remove-old")); + } + + #[tokio::test] + async fn test_add_records_config_and_instance_state() { + let mock = MockDocker::default(); + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut manager = LifecycleManager::new(mock.clone(), challenges); + let challenge_id = ChallengeId::new(); + let config = sample_config(challenge_id, "ghcr.io/org/add:v1"); + + manager.add(config.clone()).await.expect("add succeeds"); + + assert!(manager.challenges.read().contains_key(&challenge_id)); + assert!(manager.configs.read().contains_key(&challenge_id)); + + let ops = mock.operations(); + assert!(ops.contains(&format!("pull:{}", config.docker_image))); + assert!(ops.contains(&format!("start:{}", challenge_id))); + } + + #[tokio::test] + async fn test_stop_all_removes_every_challenge() { + let mock = MockDocker::default(); + let challenges = Arc::new(RwLock::new(HashMap::new())); + let mut manager = LifecycleManager::new(mock.clone(), challenges); + + let first_id = ChallengeId::new(); + let second_id = ChallengeId::new(); + + manager + .configs + .write() + .insert(first_id, sample_config(first_id, "ghcr.io/org/first:v1")); + manager + .configs + .write() + .insert(second_id, sample_config(second_id, "ghcr.io/org/second:v1")); + + manager.challenges.write().insert( + first_id, + sample_instance( + first_id, + "container-first", + "ghcr.io/org/first:v1", + ContainerStatus::Running, + ), + ); + manager.challenges.write().insert( + second_id, + sample_instance( + second_id, + "container-second", + "ghcr.io/org/second:v1", + ContainerStatus::Running, + ), + ); + + let results = manager.stop_all().await; + + assert_eq!(results.len(), 2); + assert!(results.iter().all(|(_, res)| res.is_ok())); + assert!(manager.challenges.read().is_empty()); + assert!(manager.configs.read().is_empty()); + + let ops = mock.operations(); + assert!(ops.contains(&"stop:container-first".to_string())); + assert!(ops.contains(&"remove:container-first".to_string())); + assert!(ops.contains(&"stop:container-second".to_string())); + assert!(ops.contains(&"remove:container-second".to_string())); + } + + #[derive(Clone, Default)] + struct MockDocker { + inner: Arc, + } + + #[derive(Default)] + struct MockDockerInner { + operations: Mutex>, + } + + impl MockDocker { + fn record(&self, entry: impl Into) { + self.inner.operations.lock().unwrap().push(entry.into()); + } + + fn operations(&self) -> Vec { + self.inner.operations.lock().unwrap().clone() + } + } + + #[async_trait] + impl ChallengeDocker for MockDocker { + async fn pull_image(&self, image: &str) -> anyhow::Result<()> { + self.record(format!("pull:{image}")); + Ok(()) + } + + async fn start_challenge( + &self, + config: &ChallengeContainerConfig, + ) -> anyhow::Result { + self.record(format!("start:{}", config.challenge_id)); + Ok(sample_instance( + config.challenge_id, + &format!("container-{}", config.challenge_id), + &config.docker_image, + ContainerStatus::Running, + )) + } + + async fn stop_container(&self, container_id: &str) -> anyhow::Result<()> { + self.record(format!("stop:{container_id}")); + Ok(()) + } + + async fn remove_container(&self, container_id: &str) -> anyhow::Result<()> { + self.record(format!("remove:{container_id}")); + Ok(()) + } + + async fn is_container_running(&self, container_id: &str) -> anyhow::Result { + self.record(format!("is_running:{container_id}")); + Ok(true) + } + + async fn get_logs(&self, container_id: &str, tail: usize) -> anyhow::Result { + self.record(format!("logs:{container_id}:{tail}")); + Ok(String::new()) + } + + async fn list_challenge_containers(&self) -> anyhow::Result> { + self.record("list_containers".to_string()); + Ok(Vec::new()) + } + + async fn cleanup_stale_containers( + &self, + prefix: &str, + _max_age_minutes: u64, + _exclude_patterns: &[&str], + ) -> anyhow::Result { + self.record(format!("cleanup:{prefix}")); + Ok(CleanupResult::default()) + } + } + + fn sample_config(challenge_id: ChallengeId, image: &str) -> ChallengeContainerConfig { + ChallengeContainerConfig { + challenge_id, + name: format!("challenge-{challenge_id}"), + docker_image: image.to_string(), + mechanism_id: 0, + emission_weight: 1.0, + timeout_secs: 3600, + cpu_cores: 1.0, + memory_mb: 512, + gpu_required: false, + } + } + + fn sample_instance( + challenge_id: ChallengeId, + container_id: &str, + image: &str, + status: ContainerStatus, + ) -> ChallengeInstance { + let id_str = challenge_id.to_string(); + ChallengeInstance { + challenge_id, + container_id: container_id.to_string(), + image: image.to_string(), + endpoint: format!("http://{id_str}"), + started_at: Utc::now(), + status, + } + } }