diff --git a/bitreq/examples/custom_cert.rs b/bitreq/examples/custom_cert.rs new file mode 100644 index 000000000..9c334f3b5 --- /dev/null +++ b/bitreq/examples/custom_cert.rs @@ -0,0 +1,37 @@ +//! This example demonstrates the client builder with custom DER certificate. +//! to run: cargo run --example custom_cert --features async-https-rustls + +#[cfg(not(feature = "async-https-rustls"))] +fn main() { + println!("This example requires the 'async-https-rustls' feature."); +} + +#[cfg(feature = "async-https-rustls")] +fn main() -> Result<(), bitreq::Error> { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .expect("failed to build Tokio runtime"); + + runtime.block_on(request_with_client()) +} + +#[cfg(feature = "async-https-rustls")] +async fn request_with_client() -> Result<(), bitreq::Error> { + let url = "https://example.com"; + let cert_der = include_bytes!("../tests/test_cert.der"); + let client = bitreq::Client::builder().with_root_certificate(cert_der.as_slice())?.build(); + // OR + // let cert_der: &[u8] = include_bytes!("../tests/test_cert.der"); + // let client = bitreq::Client::builder().with_root_certificate(cert_der)?.build(); + // OR + // let cert_vec: Vec = include_bytes!("../tests/test_cert.der").to_vec(); + // let client = bitreq::Client::builder().with_root_certificate(cert_vec.as_slice())?.build(); + + let response = client.send_async(bitreq::get(url)).await.unwrap(); + + println!("Status: {}", response.status_code); + println!("Body: {}", response.as_str()?); + + Ok(()) +} diff --git a/bitreq/src/client.rs b/bitreq/src/client.rs index b5de6f2fb..a7edf6f27 100644 --- a/bitreq/src/client.rs +++ b/bitreq/src/client.rs @@ -13,6 +13,193 @@ use crate::connection::AsyncConnection; use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest}; use crate::{Error, Request, Response}; +mod tls { + #[cfg(not(all(feature = "rustls", feature = "tokio-rustls")))] + pub(crate) use self::disabled::*; + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + pub(crate) use self::enabled::*; + + #[cfg(not(all(feature = "rustls", feature = "tokio-rustls")))] + mod disabled { + #[derive(Clone)] + pub(crate) struct ClientConfig; + + impl ClientConfig { + pub fn build(self) -> Self { self } + } + } + + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + mod enabled { + use crate::client::ClientBuilder; + use crate::connection::certificates::Certificates; + use crate::Error; + + #[derive(Clone)] + pub(crate) struct ClientConfig { + pub(crate) tls: Option, + } + + impl ClientConfig { + pub fn build(self) -> Self { + let tls = self.tls.map(|tls| tls.build()); + Self { tls } + } + } + + #[derive(Clone)] + pub(crate) struct TlsConfig { + pub(crate) certificates: Certificates, + } + + impl TlsConfig { + fn new(cert_der: Vec) -> Result { + let certificates = Certificates::new(Some(cert_der))?; + + Ok(Self { certificates }) + } + + fn build(mut self) -> Self { + self.certificates = self.certificates.with_root_certificates(); + self + } + } + + impl ClientBuilder { + /// Adds a custom root certificate for TLS verification. + /// + /// The certificate must be provided in DER format. This method accepts any type + /// that can be converted into a `Vec`, such as `Vec`, `&[u8]`, or arrays. + /// This is useful when connecting to servers using self-signed certificates + /// or custom Certificate Authorities. + /// + /// # Arguments + /// + /// * `cert_der` - A DER-encoded X.509 certificate. Accepts any type that implements + /// `Into>` (e.g., `&[u8]`, `Vec`, or `[u8; N]`). + /// + /// # Example + /// + /// ```no_run + /// # use bitreq::Client; + /// // Using a byte slice + /// let cert_der: &[u8] = include_bytes!("../tests/test_cert.der"); + /// let client = Client::builder() + /// .with_root_certificate(cert_der) + /// .unwrap() + /// .build(); + /// + /// // Using a Vec + /// let cert_vec: Vec = cert_der.to_vec(); + /// let client = Client::builder() + /// .with_root_certificate(cert_vec) + /// .unwrap() + /// .build(); + /// ``` + pub fn with_root_certificate>>( + mut self, + cert_der: T, + ) -> Result { + let cert_der = cert_der.into(); + + if let Some(ref mut client_config) = self.client_config { + if let Some(ref mut tls_config) = client_config.tls { + let certificates = + tls_config.certificates.clone().append_certificate(cert_der)?; + tls_config.certificates = certificates; + + return Ok(self); + } + } + + let tls_config = TlsConfig::new(cert_der)?; + self.client_config = Some(ClientConfig { tls: Some(tls_config) }); + Ok(self) + } + } + } +} + +pub(crate) use tls::ClientConfig; + +pub struct ClientBuilder { + capacity: usize, + client_config: Option, +} + +/// Builder for configuring a `Client` with custom settings. +/// +/// The builder allows you to set the connection pool capacity and add +/// custom root certificates for TLS verification before constructing the client. +/// +/// # Example +/// +/// ```no_run +/// # async fn example() -> Result<(), bitreq::Error> { +/// use bitreq::{Client, RequestExt}; +/// +/// let cert_der = include_bytes!("../tests/test_cert.der"); +/// let client = Client::builder() +/// .with_capacity(20) +/// .build(); +/// +/// let response = bitreq::get("https://example.com") +/// .send_async_with_client(&client) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +impl ClientBuilder { + /// Creates a new `ClientBuilder` with default settings. + /// + /// Default configuration: + /// * `capacity` - 10 (single connection) + /// * `root_certificates` - None (uses system certificates) + pub fn new() -> Self { Self { capacity: 10, client_config: None } } + + /// Sets the maximum number of connections to keep in the pool. + /// + /// When the pool reaches this capacity, the least recently used connection + /// is evicted to make room for new connections. + /// + /// # Arguments + /// + /// * `capacity` - Maximum number of cached connections + /// + /// # Example + /// + /// ```no_run + /// # use bitreq::Client; + /// let client = Client::builder() + /// .with_capacity(10) + /// .build(); + /// ``` + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + + /// Builds the `Client` with the configured settings. + /// + /// Consumes the builder and returns a configured `Client` instance + /// ready to send requests with connection pooling. + pub fn build(self) -> Client { + let client_config = self.client_config.map(|c| c.build()); + Client { + r#async: Arc::new(Mutex::new(ClientImpl { + connections: HashMap::new(), + lru_order: VecDeque::new(), + capacity: self.capacity, + client_config: client_config.map(Arc::new), + })), + } + } +} + +impl Default for ClientBuilder { + fn default() -> Self { Self::new() } +} + /// A client that caches connections for reuse. /// /// The client maintains a pool of up to `capacity` connections, evicting @@ -39,10 +226,11 @@ struct ClientImpl { connections: HashMap>, lru_order: VecDeque, capacity: usize, + client_config: Option>, } impl Client { - /// Creates a new `Client` with the specified connection cache capacity. + /// Creates a new `Client` with the specified connection pool capacity. /// /// # Arguments /// @@ -54,10 +242,14 @@ impl Client { connections: HashMap::new(), lru_order: VecDeque::new(), capacity, + client_config: None, })), } } + /// Create a builder for a client + pub fn builder() -> ClientBuilder { ClientBuilder::new() } + /// Sends a request asynchronously using a cached connection if available. pub async fn send_async(&self, request: Request) -> Result { let parsed_request = ParsedRequest::new(request)?; @@ -77,7 +269,13 @@ impl Client { let conn = if let Some(conn) = conn_opt { conn } else { - let connection = AsyncConnection::new(key, parsed_request.timeout_at).await?; + let client_config = { + let state = self.r#async.lock().unwrap(); + state.client_config.as_ref().map(Arc::clone) + }; + + let connection = + AsyncConnection::new(key, parsed_request.timeout_at, client_config).await?; let connection = Arc::new(connection); let mut state = self.r#async.lock().unwrap(); diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index f8b98c133..d719f400a 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -22,6 +22,8 @@ use tokio::net::TcpStream as AsyncTcpStream; #[cfg(feature = "async")] use tokio::sync::Mutex as AsyncMutex; +#[cfg(feature = "async")] +use crate::client::ClientConfig; use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest}; #[cfg(feature = "async")] use crate::Response; @@ -29,6 +31,8 @@ use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +pub(crate) mod certificates; #[cfg(feature = "rustls")] mod rustls_stream; #[cfg(feature = "rustls")] @@ -238,6 +242,7 @@ struct AsyncConnectionState { /// Defaults to 60 seconds after open to align with nginx's default timeout of 75 seconds, but /// can be overridden by the `Keep-Alive` header. socket_new_requests_timeout: Mutex, + client_config: Option>, } #[cfg(feature = "async")] @@ -266,15 +271,15 @@ impl AsyncConnection { pub(crate) async fn new( params: ConnectionParams<'_>, timeout_at: Option, + client_config: Option>, ) -> Result { + let config = client_config.as_ref().map(Arc::clone); + let future = async move { let socket = Self::connect(params).await?; if params.https { - #[cfg(not(feature = "tokio-rustls"))] - return Err(Error::HttpsFeatureNotEnabled); - #[cfg(feature = "tokio-rustls")] - rustls_stream::wrap_async_stream(socket, params.host).await + Self::wrap_async_stream(socket, params.host, config).await } else { Ok(AsyncHttpStream::Unsecured(socket)) } @@ -295,9 +300,34 @@ impl AsyncConnection { readable_request_id: AtomicUsize::new(0), min_dropped_reader_id: AtomicUsize::new(usize::MAX), socket_new_requests_timeout: Mutex::new(Instant::now() + Duration::from_secs(60)), + client_config, })))) } + /// Call the correct wrapper function depending on whether client_configs are present + #[cfg(all(feature = "rustls", feature = "tokio-rustls"))] + async fn wrap_async_stream( + socket: AsyncTcpStream, + host: &str, + client_config: Option>, + ) -> Result { + if let Some(client_config) = client_config { + rustls_stream::wrap_async_stream_with_configs(socket, host, client_config).await + } else { + rustls_stream::wrap_async_stream(socket, host).await + } + } + + /// Error treatment function, should not be called under normal circustances + #[cfg(not(all(feature = "rustls", feature = "tokio-rustls")))] + async fn wrap_async_stream( + _socket: AsyncTcpStream, + _host: &str, + _client_config: Option>, + ) -> Result { + Err(Error::HttpsFeatureNotEnabled) + } + async fn tcp_connect(host: &str, port: u16) -> Result { #[cfg(feature = "log")] log::trace!("Looking up host {host}"); @@ -446,9 +476,13 @@ impl AsyncConnection { retry_new_connection!(_internal); }; (_internal) => { - let new_connection = - AsyncConnection::new(request.connection_params(), request.timeout_at) - .await?; + let config = conn.client_config.as_ref().map(Arc::clone); + let new_connection = AsyncConnection::new( + request.connection_params(), + request.timeout_at, + config, + ) + .await?; *self.0.lock().unwrap() = Arc::clone(&*new_connection.0.lock().unwrap()); core::mem::drop(read); // Note that this cannot recurse infinitely as we'll always be able to send at @@ -806,7 +840,8 @@ async fn async_handle_redirects( let new_connection; if needs_new_connection { new_connection = - AsyncConnection::new(request.connection_params(), request.timeout_at).await?; + AsyncConnection::new(request.connection_params(), request.timeout_at, None) + .await?; connection = &new_connection; } connection.send(request).await diff --git a/bitreq/src/connection/certificates.rs b/bitreq/src/connection/certificates.rs new file mode 100644 index 000000000..37f68cd45 --- /dev/null +++ b/bitreq/src/connection/certificates.rs @@ -0,0 +1,63 @@ +#[cfg(feature = "rustls")] +use std::sync::Arc; + +#[cfg(feature = "rustls")] +use rustls::RootCertStore; +#[cfg(feature = "rustls-webpki")] +use webpki_roots::TLS_SERVER_ROOTS; + +use crate::Error; + +#[derive(Clone)] +pub(crate) struct Certificates { + pub(crate) inner: Arc, +} + +impl Certificates { + pub(crate) fn new(cert_der: Option>) -> Result { + let certificates = Self { inner: Arc::new(RootCertStore::empty()) }; + + if let Some(cert_der) = cert_der { + certificates.append_certificate(cert_der) + } else { + Ok(certificates) + } + } + + #[cfg(feature = "rustls")] + pub(crate) fn append_certificate(mut self, cert_der: Vec) -> Result { + let certificates = Arc::make_mut(&mut self.inner); + certificates.add(&rustls::Certificate(cert_der)).map_err(Error::RustlsAppendCert)?; + + Ok(self) + } + + #[cfg(feature = "rustls")] + pub(crate) fn with_root_certificates(mut self) -> Self { + let root_certificates = Arc::make_mut(&mut self.inner); + + // Try to load native certs + #[cfg(feature = "https-rustls-probe")] + if let Ok(os_roots) = rustls_native_certs::load_native_certs() { + for root_cert in os_roots { + // Ignore erroneous OS certificates, there's nothing + // to do differently in that situation anyways. + let _ = root_certificates.add(&rustls::Certificate(root_cert.0)); + } + } + + #[cfg(feature = "rustls-webpki")] + { + #[allow(deprecated)] + // Need to use add_server_trust_anchors to compile with rustls 0.21.1 + root_certificates.add_server_trust_anchors(TLS_SERVER_ROOTS.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + } + self + } +} diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index 01a3c417f..2b6b59a3b 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -24,6 +24,8 @@ use webpki_roots::TLS_SERVER_ROOTS; use super::{AsyncHttpStream, AsyncTcpStream}; #[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] use super::{AsyncHttpStream, AsyncTcpStream}; +#[cfg(feature = "async")] +use crate::client::ClientConfig as CustomClientConfig; use crate::Error; #[cfg(feature = "rustls")] @@ -63,6 +65,15 @@ fn build_client_config() -> Arc { Arc::new(config) } +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +fn build_rustls_client_config(certificates: Arc) -> Arc { + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(certificates) + .with_no_client_auth(); + Arc::new(config) +} + #[cfg(feature = "rustls")] pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] @@ -106,6 +117,33 @@ pub(super) async fn wrap_async_stream( Ok(AsyncHttpStream::Secured(Box::new(tls))) } +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +pub(super) async fn wrap_async_stream_with_configs( + tcp: AsyncTcpStream, + host: &str, + custom_client_config: Arc, +) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + let dns_name = match ServerName::try_from(host) { + Ok(result) => result, + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + let tls_config = custom_client_config.tls.as_ref().unwrap(); + let certificates = Arc::clone(&tls_config.certificates.inner); + + let client_config = build_rustls_client_config(certificates); + let connector = TlsConnector::from(client_config); + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = connector.connect(dns_name, tcp).await.map_err(Error::IoError)?; + + Ok(AsyncHttpStream::Secured(Box::new(tls))) +} + #[cfg(all(feature = "native-tls", not(feature = "rustls")))] pub type SecuredStream = TlsStream; diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index ca9d1421d..5a4c6024f 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -22,6 +22,9 @@ pub enum Error { #[cfg(feature = "rustls")] /// Ran into a rustls error while creating the connection. RustlsCreateConnection(rustls::Error), + #[cfg(feature = "rustls")] + /// Ran into a rustls error while appending a certificate. + RustlsAppendCert(rustls::Error), #[cfg(feature = "native-tls")] /// Ran into a native-tls error while creating the connection. NativeTlsCreateConnection(native_tls::Error), @@ -104,6 +107,8 @@ impl fmt::Display for Error { InvalidUtf8InBody(err) => write!(f, "{}", err), #[cfg(feature = "rustls")] RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err), + #[cfg(feature = "rustls")] + RustlsAppendCert(err) => write!(f, "error appending certificate: {}", err), #[cfg(feature = "native-tls")] NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {err}"), MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"), @@ -147,6 +152,8 @@ impl error::Error for Error { InvalidUtf8InBody(err) => Some(err), #[cfg(feature = "rustls")] RustlsCreateConnection(err) => Some(err), + #[cfg(feature = "rustls")] + RustlsAppendCert(err) => Some(err), _ => None, } } diff --git a/bitreq/src/request.rs b/bitreq/src/request.rs index d39d6d89a..24411bc0b 100644 --- a/bitreq/src/request.rs +++ b/bitreq/src/request.rs @@ -327,7 +327,7 @@ impl Request { #[cfg(feature = "async")] pub async fn send_async(self) -> Result { let parsed_request = ParsedRequest::new(self)?; - AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at) + AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at, None) .await? .send(parsed_request) .await diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 8d357f354..6dd077500 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -16,6 +16,35 @@ async fn test_https() { assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); } +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_client_builder() { + setup(); + let client = bitreq::Client::builder().build(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_client_builder_and_cert() { + setup(); + let cert_der = include_bytes!("test_cert.der"); + let client = + bitreq::Client::builder().with_root_certificate(cert_der.as_slice()).unwrap().build(); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + #[tokio::test] #[cfg(feature = "json-using-serde")] async fn test_json_using_serde() { diff --git a/bitreq/tests/test_cert.der b/bitreq/tests/test_cert.der new file mode 100644 index 000000000..f8d4129e3 Binary files /dev/null and b/bitreq/tests/test_cert.der differ