diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 55d02cd8b155..54473bcdd831 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -3528,6 +3528,8 @@ dependencies = [ "rama-tls-rustls", "rama-unix", "rustls-native-certs", + "schannel", + "security-framework 3.5.1", "serde", "serde_json", "sha2 0.10.9", diff --git a/codex-rs/network-proxy/Cargo.toml b/codex-rs/network-proxy/Cargo.toml index cd32ff92b22e..cc6a8de7104f 100644 --- a/codex-rs/network-proxy/Cargo.toml +++ b/codex-rs/network-proxy/Cargo.toml @@ -44,3 +44,9 @@ tempfile = { workspace = true } [target.'cfg(target_family = "unix")'.dependencies] rama-unix = { version = "=0.3.0-alpha.4" } + +[target.'cfg(target_os = "macos")'.dependencies] +security-framework = "3" + +[target.'cfg(windows)'.dependencies] +schannel = "0.1" diff --git a/codex-rs/network-proxy/src/certs.rs b/codex-rs/network-proxy/src/certs.rs index 001469aa9a65..2ff7cc03b459 100644 --- a/codex-rs/network-proxy/src/certs.rs +++ b/codex-rs/network-proxy/src/certs.rs @@ -23,6 +23,7 @@ use rama_tls_rustls::server::TlsAcceptorData; use sha2::Digest as _; use sha2::Sha256; use std::collections::HashMap; +use std::collections::HashSet; use std::fs; use std::fs::File; use std::fs::OpenOptions; @@ -30,6 +31,7 @@ use std::io::Write; use std::net::IpAddr; use std::path::Path; use std::path::PathBuf; +use std::sync::Arc; use std::time::SystemTime; use std::time::UNIX_EPOCH; use tracing::info; @@ -101,6 +103,7 @@ const MANAGED_MITM_CA_DIR: &str = "proxy"; const MANAGED_MITM_CA_CERT: &str = "ca.pem"; const MANAGED_MITM_CA_KEY: &str = "ca.key"; const MANAGED_MITM_CA_TRUST_BUNDLE_PREFIX: &str = "ca-bundle"; +pub(crate) const SSL_CERT_DIR_ENV_KEY: &str = "SSL_CERT_DIR"; // Best-effort compatibility set for common child toolchains that accept a CA bundle path. // This is intentionally curated rather than pretending to cover every TLS client. @@ -117,6 +120,14 @@ pub const CUSTOM_CA_ENV_KEYS: [&str; 10] = [ "NPM_CONFIG_CAFILE", ]; +pub(crate) fn ca_env_from_process() -> HashMap<&'static str, String> { + CUSTOM_CA_ENV_KEYS + .into_iter() + .chain([SSL_CERT_DIR_ENV_KEY]) + .filter_map(|key| std::env::var(key).ok().map(|value| (key, value))) + .collect() +} + /// Immutable managed MITM CA bundle path plus startup TLS env values. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct ManagedMitmCaTrustBundle { @@ -146,38 +157,203 @@ fn managed_ca_trust_bundle_for_cert_path( cert_path: &Path, env: &HashMap<&'static str, String>, ) -> Result { - let startup_env_values = CUSTOM_CA_ENV_KEYS + let startup_env_values = startup_ca_file_env_values(env); + let startup_cert_dir = env + .get(SSL_CERT_DIR_ENV_KEY) + .filter(|value| !value.is_empty()) + .map(String::as_str); + let trust_bundle = + build_managed_ca_trust_bundle(cert_path, &startup_env_values, startup_cert_dir)?; + let path = persist_managed_ca_trust_bundle(cert_path, &trust_bundle)?; + + Ok(ManagedMitmCaTrustBundle { + path, + startup_env_values, + }) +} + +pub(crate) fn upstream_tls_root_store( + env: &HashMap<&'static str, String>, +) -> Result> { + let (managed_ca_cert_path, _) = managed_ca_paths()?; + upstream_tls_root_store_for_cert_path(&managed_ca_cert_path, env) +} + +pub(crate) fn upstream_tls_root_store_for_cert_path( + managed_ca_cert_path: &Path, + env: &HashMap<&'static str, String>, +) -> Result> { + let startup_env_values = startup_ca_file_env_values(env); + let startup_cert_dir = env + .get(SSL_CERT_DIR_ENV_KEY) + .filter(|value| !value.is_empty()) + .map(String::as_str); + let certificates = load_platform_and_startup_root_certificates( + managed_ca_cert_path, + &startup_env_values, + startup_cert_dir, + )?; + let mut roots = rustls::RootCertStore::empty(); + let (_, ignored) = roots.add_parsable_certificates(certificates); + if ignored > 0 { + warn!( + ignored_root_count = ignored, + "ignored invalid platform or startup roots for MITM upstream TLS" + ); + } + Ok(Arc::new(roots)) +} + +fn startup_ca_file_env_values( + env: &HashMap<&'static str, String>, +) -> HashMap<&'static str, String> { + CUSTOM_CA_ENV_KEYS .into_iter() .filter_map(|key| { env.get(key) .filter(|value| !value.is_empty()) .map(|value| (key, value.clone())) }) - .collect(); - let trust_bundle = build_managed_ca_trust_bundle(cert_path)?; - let path = persist_managed_ca_trust_bundle(cert_path, &trust_bundle)?; + .collect() +} - Ok(ManagedMitmCaTrustBundle { - path, +fn build_managed_ca_trust_bundle( + managed_ca_cert_path: &Path, + startup_env_values: &HashMap<&'static str, String>, + startup_cert_dir: Option<&str>, +) -> Result { + let mut trust_bundle = String::new(); + for cert in load_platform_and_startup_root_certificates( + managed_ca_cert_path, startup_env_values, - }) + startup_cert_dir, + )? { + push_certificate_pem(&mut trust_bundle, cert.as_ref()); + } + append_pem_file(&mut trust_bundle, managed_ca_cert_path)?; + Ok(trust_bundle) } -fn build_managed_ca_trust_bundle(managed_ca_cert_path: &Path) -> Result { - let mut trust_bundle = String::new(); +fn load_platform_and_startup_root_certificates( + managed_ca_cert_path: &Path, + startup_env_values: &HashMap<&'static str, String>, + startup_cert_dir: Option<&str>, +) -> Result>> { + let managed_ca_cert = fs::read(managed_ca_cert_path).with_context(|| { + format!( + "failed to read managed MITM CA certificate: {}", + managed_ca_cert_path.display() + ) + })?; + let managed_ca_cert = CertificateDer::from_pem_slice(&managed_ca_cert) + .context("failed to parse managed MITM CA certificate")?; let rustls_native_certs::CertificateResult { certs, errors, .. } = - rustls_native_certs::load_native_certs(); + crate::native_certs::load_platform_native_certs(); if !errors.is_empty() { warn!( native_root_error_count = errors.len(), "encountered errors while loading native root certificates for MITM trust bundle" ); } - for cert in certs { - push_certificate_pem(&mut trust_bundle, cert.as_ref()); + let mut certificates = certs; + let mut appended_startup_paths = HashSet::new(); + for path in CUSTOM_CA_ENV_KEYS + .into_iter() + .filter_map(|key| startup_env_values.get(key)) + .map(PathBuf::from) + { + if path != managed_ca_cert_path + && !is_current_generated_trust_bundle_path(&path, managed_ca_cert_path) + && appended_startup_paths.insert(path.clone()) + { + certificates.extend(read_ca_certificates(&path)?); + } } - append_pem_file(&mut trust_bundle, managed_ca_cert_path)?; - Ok(trust_bundle) + if let Some(startup_cert_dir) = startup_cert_dir { + for path in std::env::split_paths(startup_cert_dir) { + if appended_startup_paths.insert(path.clone()) { + certificates.extend(load_ca_directory_certificates(&path)); + } + } + } + let mut seen = HashSet::new(); + certificates.retain(|cert| cert != &managed_ca_cert && seen.insert(cert.as_ref().to_vec())); + Ok(certificates) +} + +fn read_ca_certificates(path: &Path) -> Result>> { + let pem = fs::read(path) + .with_context(|| format!("failed to read startup CA bundle: {}", path.display()))?; + let pem = String::from_utf8_lossy(&pem); + let contains_trusted_certificates = pem.contains("TRUSTED CERTIFICATE"); + let normalized_pem = pem + .replace("BEGIN TRUSTED CERTIFICATE", "BEGIN CERTIFICATE") + .replace("END TRUSTED CERTIFICATE", "END CERTIFICATE"); + let certs = CertificateDer::pem_slice_iter(normalized_pem.as_bytes()) + .collect::, _>>() + .with_context(|| format!("failed to parse startup CA bundle: {}", path.display()))?; + if certs.is_empty() { + return Err(anyhow!( + "startup CA bundle contained no certificates: {}", + path.display() + )); + } + certs + .into_iter() + .map(|cert| { + let cert = if contains_trusted_certificates { + first_der_item(cert.as_ref()).ok_or_else(|| { + anyhow!( + "startup CA bundle contained an invalid trusted certificate: {}", + path.display() + ) + })? + } else { + cert.as_ref() + }; + Ok(CertificateDer::from(cert.to_vec())) + }) + .collect() +} + +fn load_ca_directory_certificates(path: &Path) -> Vec> { + let rustls_native_certs::CertificateResult { certs, errors, .. } = + rustls_native_certs::load_certs_from_paths(None, Some(path)); + if !errors.is_empty() { + warn!( + ca_path = %path.display(), + ca_error_count = errors.len(), + "encountered errors while loading startup CA directory" + ); + } + certs +} + +fn first_der_item(der: &[u8]) -> Option<&[u8]> { + der_item_length(der).map(|length| &der[..length]) +} + +fn der_item_length(der: &[u8]) -> Option { + let &length_octet = der.get(1)?; + if length_octet & 0x80 == 0 { + return Some(2 + usize::from(length_octet)).filter(|length| *length <= der.len()); + } + + let length_octets = usize::from(length_octet & 0x7f); + if length_octets == 0 { + return None; + } + + let length_end = 2usize.checked_add(length_octets)?; + let mut content_length = 0usize; + for &byte in der.get(2..length_end)? { + content_length = content_length + .checked_mul(256)? + .checked_add(usize::from(byte))?; + } + length_end + .checked_add(content_length) + .filter(|length| *length <= der.len()) } fn is_current_generated_trust_bundle_path(path: &Path, managed_ca_cert_path: &Path) -> bool { @@ -508,17 +684,80 @@ mod tests { } #[test] - fn managed_ca_trust_bundle_records_startup_ca_env_values() { + fn managed_ca_trust_bundle_appends_startup_file_and_directory_certificates() { let dir = tempdir().unwrap(); let managed_ca_cert_path = dir.path().join("ca.pem"); - fs::write(&managed_ca_cert_path, "managed ca\n").unwrap(); - let env = HashMap::from([("SSL_CERT_FILE", "/tmp/startup-ca.pem".to_string())]); + let startup_ca_bundle_path = dir.path().join("startup-ca.pem"); + let startup_ca_dir = dir.path().join("startup-certs"); + let (managed_ca_cert, _) = generate_ca().unwrap(); + let (startup_ca_cert, startup_ca_key) = generate_ca().unwrap(); + let (directory_ca_cert, _) = generate_ca().unwrap(); + let mut trusted_ca_der = CertificateDer::from_pem_slice(startup_ca_cert.as_bytes()) + .unwrap() + .as_ref() + .to_vec(); + trusted_ca_der.extend_from_slice(&[0x30, 0x00]); + let mut trusted_ca_cert = String::new(); + push_certificate_pem(&mut trusted_ca_cert, &trusted_ca_der); + let trusted_ca_cert = trusted_ca_cert.replace("CERTIFICATE", "TRUSTED CERTIFICATE"); + fs::write(&managed_ca_cert_path, &managed_ca_cert).unwrap(); + fs::write( + &startup_ca_bundle_path, + format!("{trusted_ca_cert}{startup_ca_key}"), + ) + .unwrap(); + fs::create_dir(&startup_ca_dir).unwrap(); + fs::write(startup_ca_dir.join("directory-ca.pem"), &directory_ca_cert).unwrap(); + let startup_ca_bundle_path = startup_ca_bundle_path.display().to_string(); + let env = HashMap::from([ + ("SSL_CERT_FILE", startup_ca_bundle_path.clone()), + (SSL_CERT_DIR_ENV_KEY, startup_ca_dir.display().to_string()), + ]); + let trust_bundle = managed_ca_trust_bundle_for_cert_path(&managed_ca_cert_path, &env).unwrap(); assert_eq!( trust_bundle.startup_env_values, - HashMap::from([("SSL_CERT_FILE", "/tmp/startup-ca.pem".to_string())]) + HashMap::from([("SSL_CERT_FILE", startup_ca_bundle_path)]) + ); + let baseline_bundle = fs::read_to_string(&trust_bundle.path).unwrap(); + let baseline_certs = CertificateDer::pem_slice_iter(baseline_bundle.as_bytes()) + .collect::, _>>() + .unwrap(); + let expected_certs = [&startup_ca_cert, &directory_ca_cert, &managed_ca_cert] + .map(|cert| CertificateDer::from_pem_slice(cert.as_bytes()).unwrap()); + + assert!( + expected_certs + .iter() + .all(|cert| baseline_certs.contains(cert)) ); + assert!(!baseline_bundle.contains(&startup_ca_key)); + assert!(!baseline_bundle.contains("TRUSTED CERTIFICATE")); + } + + #[test] + fn managed_ca_trust_bundle_skips_inherited_current_bundle() { + let dir = tempdir().unwrap(); + let managed_ca_cert_path = dir.path().join("ca.pem"); + let inherited_bundle_path = dir.path().join("ca-bundle-parent.pem"); + let (managed_ca_cert, _) = generate_ca().unwrap(); + fs::write(&managed_ca_cert_path, &managed_ca_cert).unwrap(); + fs::write( + &inherited_bundle_path, + format!("parent roots\n{managed_ca_cert}"), + ) + .unwrap(); + let env = HashMap::from([( + "REQUESTS_CA_BUNDLE", + inherited_bundle_path.display().to_string(), + )]); + + let trust_bundle = + managed_ca_trust_bundle_for_cert_path(&managed_ca_cert_path, &env).unwrap(); + let baseline_bundle = fs::read_to_string(&trust_bundle.path).unwrap(); + + assert_eq!(baseline_bundle.matches(&managed_ca_cert).count(), 1); } #[cfg(unix)] diff --git a/codex-rs/network-proxy/src/lib.rs b/codex-rs/network-proxy/src/lib.rs index e80b154a5eae..c8950f1f5592 100644 --- a/codex-rs/network-proxy/src/lib.rs +++ b/codex-rs/network-proxy/src/lib.rs @@ -6,6 +6,7 @@ mod connect_policy; mod http_proxy; mod mitm; mod mitm_hook; +mod native_certs; mod network_policy; mod policy; mod proxy; diff --git a/codex-rs/network-proxy/src/mitm.rs b/codex-rs/network-proxy/src/mitm.rs index 345c5b503296..70175ca70602 100644 --- a/codex-rs/network-proxy/src/mitm.rs +++ b/codex-rs/network-proxy/src/mitm.rs @@ -107,11 +107,19 @@ impl MitmState { // generate/load a local CA and issue per-host leaf certs so we can terminate TLS and // apply policy. let ca = ManagedMitmCa::load_or_create()?; + let upstream_tls_root_store = + crate::certs::upstream_tls_root_store(&crate::certs::ca_env_from_process())?; let upstream = if config.allow_upstream_proxy { - UpstreamClient::from_env_proxy_with_allow_local_binding(config.allow_local_binding) + UpstreamClient::from_env_proxy_with_allow_local_binding( + config.allow_local_binding, + upstream_tls_root_store, + ) } else { - UpstreamClient::direct_with_allow_local_binding(config.allow_local_binding) + UpstreamClient::direct_with_allow_local_binding( + config.allow_local_binding, + upstream_tls_root_store, + ) }; Ok(Self { diff --git a/codex-rs/network-proxy/src/native_certs.rs b/codex-rs/network-proxy/src/native_certs.rs new file mode 100644 index 000000000000..6c8d45005db8 --- /dev/null +++ b/codex-rs/network-proxy/src/native_certs.rs @@ -0,0 +1,260 @@ +#[cfg(any(target_os = "macos", windows))] +use rama_tls_rustls::dep::pki_types::CertificateDer; +use rustls_native_certs::CertificateResult; +#[cfg(any(target_os = "macos", windows))] +use rustls_native_certs::Error; +#[cfg(any(target_os = "macos", windows))] +use rustls_native_certs::ErrorKind; + +// `rustls_native_certs::load_native_certs()` first consults SSL_CERT_FILE and +// SSL_CERT_DIR. Load platform roots directly so a startup custom CA can be +// layered onto the managed bundle without replacing the platform trust store. +#[cfg(all(unix, not(target_os = "macos")))] +pub(crate) fn load_platform_native_certs() -> CertificateResult { + let mut result = + rustls_native_certs::load_certs_from_paths(platform_cert_file().as_deref(), None); + for cert_dir in platform_cert_dirs() { + extend_certificate_result( + &mut result, + rustls_native_certs::load_certs_from_paths(None, Some(&cert_dir)), + ); + } + dedupe_certs(&mut result); + result +} + +#[cfg(target_os = "macos")] +pub(crate) fn load_platform_native_certs() -> CertificateResult { + use security_framework::trust_settings::Domain; + use security_framework::trust_settings::TrustSettings; + use security_framework::trust_settings::TrustSettingsForCertificate; + use std::collections::BTreeMap; + + let mut result = CertificateResult::default(); + let mut all_certs = BTreeMap::new(); + for domain in &[Domain::User, Domain::Admin, Domain::System] { + let ts = TrustSettings::new(*domain); + let iter = match ts.iter() { + Ok(iter) => iter, + Err(err) => { + result.errors.push(Error { + context: match domain { + Domain::User => "failed to load user trust settings", + Domain::Admin => "failed to load admin trust settings", + Domain::System => "failed to load system trust settings", + }, + kind: ErrorKind::Os(err.into()), + }); + continue; + } + }; + + for cert in iter { + let der = cert.to_der(); + let trusted = match ts.tls_trust_settings_for_certificate(&cert) { + Ok(trusted) => trusted.unwrap_or(TrustSettingsForCertificate::TrustRoot), + Err(err) => { + result.errors.push(Error { + context: "certificate not trusted", + kind: ErrorKind::Os(err.into()), + }); + continue; + } + }; + all_certs.entry(der).or_insert(trusted); + } + } + + for (der, trusted) in all_certs { + use TrustSettingsForCertificate::*; + + if let TrustRoot | TrustAsRoot = trusted { + result.certs.push(CertificateDer::from(der)); + } + } + result +} + +#[cfg(windows)] +pub(crate) fn load_platform_native_certs() -> CertificateResult { + use schannel::cert_store::CertStore; + + let mut result = CertificateResult::default(); + let current_user_store = match CertStore::open_current_user("ROOT") { + Ok(store) => store, + Err(err) => { + result.errors.push(Error { + context: "failed to open current user certificate store", + kind: ErrorKind::Os(err.into()), + }); + return result; + } + }; + + for cert in current_user_store.certs() { + let valid_uses = match cert.valid_uses() { + Ok(valid_uses) => valid_uses, + Err(err) => { + result.errors.push(Error { + context: "failed to inspect certificate valid uses", + kind: ErrorKind::Os(err.into()), + }); + continue; + } + }; + let is_time_valid = match cert.is_time_valid() { + Ok(is_time_valid) => is_time_valid, + Err(err) => { + result.errors.push(Error { + context: "failed to inspect certificate time validity", + kind: ErrorKind::Os(err.into()), + }); + continue; + } + }; + if usable_for_rustls(valid_uses) && is_time_valid { + result + .certs + .push(CertificateDer::from(cert.to_der().to_vec())); + } + } + result +} + +#[cfg(not(any(all(unix, not(target_os = "macos")), target_os = "macos", windows)))] +pub(crate) fn load_platform_native_certs() -> CertificateResult { + rustls_native_certs::load_native_certs() +} + +#[cfg(all(unix, not(target_os = "macos")))] +fn extend_certificate_result(result: &mut CertificateResult, extra: CertificateResult) { + result.certs.extend(extra.certs); + result.errors.extend(extra.errors); +} + +#[cfg(all(unix, not(target_os = "macos")))] +fn dedupe_certs(result: &mut CertificateResult) { + result.certs.sort_unstable_by(|a, b| a.cmp(b)); + result.certs.dedup(); +} + +#[cfg(all(unix, not(target_os = "macos")))] +fn platform_cert_file() -> Option { + PLATFORM_CERTIFICATE_FILE_NAMES + .iter() + .map(std::path::Path::new) + .find(|path| path.exists()) + .map(std::path::Path::to_path_buf) +} + +#[cfg(all(unix, not(target_os = "macos")))] +fn platform_cert_dirs() -> impl Iterator { + PLATFORM_CERTIFICATE_DIRS + .iter() + .map(std::path::Path::new) + .filter(|path| path.exists()) + .map(std::path::Path::to_path_buf) +} + +#[cfg(all(unix, not(target_os = "macos"), target_os = "linux"))] +const PLATFORM_CERTIFICATE_DIRS: &[&str] = &[ + "/etc/ssl/certs", + "/etc/pki/tls/certs", + "/etc/security/certificates", +]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "freebsd"))] +const PLATFORM_CERTIFICATE_DIRS: &[&str] = &["/etc/ssl/certs", "/usr/local/share/certs"]; + +#[cfg(all( + unix, + not(target_os = "macos"), + any(target_os = "illumos", target_os = "solaris") +))] +const PLATFORM_CERTIFICATE_DIRS: &[&str] = &["/etc/certs/CA"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "netbsd"))] +const PLATFORM_CERTIFICATE_DIRS: &[&str] = &["/etc/openssl/certs"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "aix"))] +const PLATFORM_CERTIFICATE_DIRS: &[&str] = &["/var/ssl/certs"]; + +#[cfg(all( + unix, + not(target_os = "macos"), + not(any( + target_os = "linux", + target_os = "freebsd", + target_os = "illumos", + target_os = "solaris", + target_os = "netbsd", + target_os = "aix" + )) +))] +const PLATFORM_CERTIFICATE_DIRS: &[&str] = &["/etc/ssl/certs"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "linux"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = &[ + "/etc/ssl/certs/ca-certificates.crt", + "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", + "/etc/pki/tls/certs/ca-bundle.crt", + "/etc/ssl/ca-bundle.pem", + "/etc/pki/tls/cacert.pem", + "/etc/ssl/cert.pem", + "/opt/etc/ssl/certs/ca-certificates.crt", + "/etc/ssl/certs/cacert.pem", +]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "freebsd"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = &["/usr/local/etc/ssl/cert.pem"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "dragonfly"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = &["/usr/local/share/certs/ca-root-nss.crt"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "netbsd"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = &["/etc/openssl/certs/ca-certificates.crt"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "openbsd"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = &["/etc/ssl/cert.pem"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "solaris"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = &["/etc/certs/ca-certificates.crt"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "illumos"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = + &["/etc/ssl/cacert.pem", "/etc/certs/ca-certificates.crt"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "android"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = + &["/data/data/com.termux/files/usr/etc/tls/cert.pem"]; + +#[cfg(all(unix, not(target_os = "macos"), target_os = "haiku"))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = &["/boot/system/data/ssl/CARootCertificates.pem"]; + +#[cfg(all( + unix, + not(target_os = "macos"), + not(any( + target_os = "linux", + target_os = "freebsd", + target_os = "dragonfly", + target_os = "netbsd", + target_os = "openbsd", + target_os = "solaris", + target_os = "illumos", + target_os = "android", + target_os = "haiku", + )) +))] +const PLATFORM_CERTIFICATE_FILE_NAMES: &[&str] = &["/etc/ssl/certs/ca-certificates.crt"]; + +#[cfg(windows)] +fn usable_for_rustls(uses: schannel::cert_context::ValidUses) -> bool { + match uses { + schannel::cert_context::ValidUses::All => true, + schannel::cert_context::ValidUses::Oids(strs) => strs.iter().any(|x| x == PKIX_SERVER_AUTH), + } +} + +#[cfg(windows)] +const PKIX_SERVER_AUTH: &str = "1.3.6.1.5.5.7.3.1"; diff --git a/codex-rs/network-proxy/src/proxy.rs b/codex-rs/network-proxy/src/proxy.rs index c3685a4310aa..cb56059db762 100644 --- a/codex-rs/network-proxy/src/proxy.rs +++ b/codex-rs/network-proxy/src/proxy.rs @@ -306,10 +306,7 @@ struct NetworkProxyRuntimeSettings { impl NetworkProxyRuntimeSettings { fn from_config(config: &config::NetworkProxyConfig) -> Result { let mitm_ca_trust_bundle = if config.network.mitm { - let env = crate::certs::CUSTOM_CA_ENV_KEYS - .into_iter() - .filter_map(|key| std::env::var(key).ok().map(|value| (key, value))) - .collect(); + let env = crate::certs::ca_env_from_process(); Some(crate::certs::managed_ca_trust_bundle(&env)?) } else { None diff --git a/codex-rs/network-proxy/src/upstream.rs b/codex-rs/network-proxy/src/upstream.rs index 3437b0d32deb..5d36706ffb65 100644 --- a/codex-rs/network-proxy/src/upstream.rs +++ b/codex-rs/network-proxy/src/upstream.rs @@ -21,6 +21,8 @@ use rama_net::client::EstablishedClientConnection; use rama_net::http::RequestContext; use rama_tls_rustls::client::TlsConnectorDataBuilder; use rama_tls_rustls::client::TlsConnectorLayer; +use rama_tls_rustls::client::client_root_certs; +use rama_tls_rustls::dep::rustls; use std::sync::Arc; use std::time::Instant; use tracing::info; @@ -104,6 +106,7 @@ impl UpstreamClient { Self::new( ProxyConfig::default(), TargetCheckedTcpConnector::new(state), + client_root_certs(), ) } @@ -111,20 +114,29 @@ impl UpstreamClient { Self::new( ProxyConfig::from_env(), TargetCheckedTcpConnector::new(state), + client_root_certs(), ) } - pub(crate) fn direct_with_allow_local_binding(allow_local_binding: bool) -> Self { + pub(crate) fn direct_with_allow_local_binding( + allow_local_binding: bool, + tls_root_store: Arc, + ) -> Self { Self::new( ProxyConfig::default(), TargetCheckedTcpConnector::from_allow_local_binding(allow_local_binding), + tls_root_store, ) } - pub(crate) fn from_env_proxy_with_allow_local_binding(allow_local_binding: bool) -> Self { + pub(crate) fn from_env_proxy_with_allow_local_binding( + allow_local_binding: bool, + tls_root_store: Arc, + ) -> Self { Self::new( ProxyConfig::from_env(), TargetCheckedTcpConnector::from_allow_local_binding(allow_local_binding), + tls_root_store, ) } @@ -137,8 +149,12 @@ impl UpstreamClient { } } - fn new(proxy_config: ProxyConfig, transport: TargetCheckedTcpConnector) -> Self { - let connector = build_http_connector(transport); + fn new( + proxy_config: ProxyConfig, + transport: TargetCheckedTcpConnector, + tls_root_store: Arc, + ) -> Self { + let connector = build_http_connector(transport, tls_root_store); Self { connector, proxy_config, @@ -221,6 +237,7 @@ impl Service> for UpstreamClient { fn build_http_connector( transport: TargetCheckedTcpConnector, + tls_root_store: Arc, ) -> BoxService< Request, EstablishedClientConnection, Request>, @@ -228,7 +245,10 @@ fn build_http_connector( > { ensure_rustls_crypto_provider(); let proxy = HttpProxyConnectorLayer::optional().into_layer(transport); - let tls_config = TlsConnectorDataBuilder::new() + let client_config = rustls::ClientConfig::builder_with_protocol_versions(rustls::ALL_VERSIONS) + .with_root_certificates(tls_root_store) + .with_no_client_auth(); + let tls_config = TlsConnectorDataBuilder::from(client_config) .with_alpn_protocols_http_auto() .build(); let tls = TlsConnectorLayer::auto() @@ -239,6 +259,10 @@ fn build_http_connector( connector.boxed() } +#[cfg(test)] +#[path = "upstream_tests.rs"] +mod tests; + #[cfg(target_os = "macos")] fn build_unix_connector( path: &str, diff --git a/codex-rs/network-proxy/src/upstream_tests.rs b/codex-rs/network-proxy/src/upstream_tests.rs new file mode 100644 index 000000000000..85f5da9017b7 --- /dev/null +++ b/codex-rs/network-proxy/src/upstream_tests.rs @@ -0,0 +1,111 @@ +use super::*; +use pretty_assertions::assert_eq; +use rama_http::StatusCode; +use rama_tls_rustls::dep::pki_types::CertificateDer; +use rama_tls_rustls::dep::pki_types::PrivateKeyDer; +use rama_tls_rustls::dep::pki_types::pem::PemObject; +use rama_tls_rustls::dep::rcgen::BasicConstraints; +use rama_tls_rustls::dep::rcgen::CertificateParams; +use rama_tls_rustls::dep::rcgen::DistinguishedName; +use rama_tls_rustls::dep::rcgen::DnType; +use rama_tls_rustls::dep::rcgen::ExtendedKeyUsagePurpose; +use rama_tls_rustls::dep::rcgen::IsCa; +use rama_tls_rustls::dep::rcgen::Issuer; +use rama_tls_rustls::dep::rcgen::KeyPair; +use rama_tls_rustls::dep::rcgen::KeyUsagePurpose; +use rama_tls_rustls::dep::rcgen::PKCS_ECDSA_P256_SHA256; +use rama_tls_rustls::dep::tokio_rustls::TlsAcceptor; +use std::collections::HashMap; +use std::fs; +use std::sync::Arc; +use tempfile::tempdir; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; + +fn generate_ca(common_name: &str) -> (String, KeyPair) { + let mut params = CertificateParams::default(); + params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + params.key_usages = vec![ + KeyUsagePurpose::KeyCertSign, + KeyUsagePurpose::DigitalSignature, + KeyUsagePurpose::KeyEncipherment, + ]; + let mut distinguished_name = DistinguishedName::new(); + distinguished_name.push(DnType::CommonName, common_name); + params.distinguished_name = distinguished_name; + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); + let cert = params.self_signed(&key_pair).unwrap(); + (cert.pem(), key_pair) +} + +#[tokio::test] +async fn mitm_upstream_client_trusts_startup_custom_ca() { + ensure_rustls_crypto_provider(); + let temp_dir = tempdir().unwrap(); + let startup_ca_path = temp_dir.path().join("startup-ca.pem"); + let managed_ca_path = temp_dir.path().join("managed-ca.pem"); + let (startup_ca_pem, startup_ca_key) = generate_ca("startup CA"); + let (managed_ca_pem, _) = generate_ca("managed MITM CA"); + fs::write(&startup_ca_path, &startup_ca_pem).unwrap(); + fs::write(&managed_ca_path, managed_ca_pem).unwrap(); + + let issuer = Issuer::from_ca_cert_pem(&startup_ca_pem, startup_ca_key).unwrap(); + let mut server_params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); + server_params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; + server_params.key_usages = vec![ + KeyUsagePurpose::DigitalSignature, + KeyUsagePurpose::KeyEncipherment, + ]; + let server_key = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); + let server_cert = server_params.signed_by(&server_key, &issuer).unwrap(); + let server_cert = CertificateDer::from_pem_slice(server_cert.pem().as_bytes()).unwrap(); + let server_key = PrivateKeyDer::from_pem_slice(server_key.serialize_pem().as_bytes()).unwrap(); + let mut server_config = + rustls::ServerConfig::builder_with_protocol_versions(rustls::ALL_VERSIONS) + .with_no_client_auth() + .with_single_cert(vec![server_cert], server_key) + .unwrap(); + server_config.alpn_protocols = vec![b"http/1.1".to_vec()]; + + let env = HashMap::from([( + "SSL_CERT_FILE", + startup_ca_path.to_string_lossy().into_owned(), + )]); + let roots = + crate::certs::upstream_tls_root_store_for_cert_path(&managed_ca_path, &env).unwrap(); + let baseline_roots = + crate::certs::upstream_tls_root_store_for_cert_path(&managed_ca_path, &HashMap::new()) + .unwrap(); + assert_eq!(roots.len(), baseline_roots.len() + 1); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + let acceptor = TlsAcceptor::from(Arc::new(server_config)); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut stream = acceptor.accept(stream).await.unwrap(); + let mut request = [0; 4096]; + let bytes_read = stream.read(&mut request).await.unwrap(); + assert!(bytes_read > 0); + stream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + }); + + let client = + UpstreamClient::direct_with_allow_local_binding(/*allow_local_binding*/ true, roots); + let response = client + .serve( + Request::builder() + .uri(format!("https://localhost:{}/", address.port())) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + server.await.unwrap(); +}