Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion crates/sprout-agent/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,17 @@ async fn backoff_with_jitter(attempt: u32) {
tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
}

/// Transport-layer errors safe to retry for non-streaming LLM POSTs.
///
/// Covers timeouts, connect failures, and the broader request-class errors
/// reqwest reports for pre-response failures: TLS handshake aborts, sockets
/// dropped or reset mid-send, h2 GOAWAY/RST_STREAM, hyper protocol errors.
/// Body-serialization happens before the retry loop, so `is_request()` here
/// is always a network failure, never a malformed request we'd just resend.
fn is_retryable_transport_error(e: &reqwest::Error) -> bool {
e.is_timeout() || e.is_connect() || e.is_request()
}

async fn post<F>(http: &Client, url: &str, body: &Value, apply: F) -> Result<Value, AgentError>
where
F: Fn(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
Expand All @@ -692,7 +703,13 @@ where
{
Ok(r) => r,
Err(e) => {
if attempt + 1 < MAX_RETRIES && (e.is_timeout() || e.is_connect()) {
if attempt + 1 < MAX_RETRIES && is_retryable_transport_error(&e) {
tracing::warn!(
attempt = attempt + 1,
max_attempts = MAX_RETRIES,
error = %e,
"llm: transport error, retrying"
);
backoff_with_jitter(attempt).await;
continue;
}
Expand All @@ -704,6 +721,12 @@ where
return Err(AgentError::LlmAuth(read_error_body(resp).await));
}
if (status.is_server_error() || status == 429) && attempt + 1 < MAX_RETRIES {
tracing::warn!(
attempt = attempt + 1,
max_attempts = MAX_RETRIES,
%status,
"llm: retryable status, retrying"
);
backoff_with_jitter(attempt).await;
continue;
}
Expand Down Expand Up @@ -1092,4 +1115,73 @@ mod tests {
"data:image/png;base64,aW1n"
);
}

/// Regression: a connection that is accepted and then dropped before any
/// HTTP response bytes are written surfaces as a reqwest request-class
/// error (not `is_connect()`, not `is_timeout()`). The retry predicate
/// must recognize it; otherwise transient TLS/h2/proxy hiccups bubble
/// out of the agent as `transport: error sending request ...`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn post_retries_on_dropped_connection_before_response() {
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let url = format!("http://{}/v1/x", listener.local_addr().unwrap());
let accepts = Arc::new(AtomicU32::new(0));
let accepts_srv = accepts.clone();

tokio::spawn(async move {
loop {
let (mut sock, _) = match listener.accept().await {
Ok(p) => p,
Err(_) => return,
};
let n = accepts_srv.fetch_add(1, Ordering::SeqCst);
if n == 0 {
// First attempt: read the request, then drop the socket
// without writing a response. reqwest surfaces this as
// a request-class error (is_request() == true).
let mut tmp = [0u8; 4096];
let _ = sock.read(&mut tmp).await;
drop(sock);
continue;
}
// Subsequent attempts: serve a tiny JSON body.
let mut buf = Vec::new();
let mut tmp = [0u8; 4096];
while !buf.windows(4).any(|w| w == b"\r\n\r\n") {
match sock.read(&mut tmp).await {
Ok(0) | Err(_) => return,
Ok(k) => buf.extend_from_slice(&tmp[..k]),
}
}
let body = "{\"ok\":true}";
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\
Content-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body,
);
let _ = sock.write_all(resp.as_bytes()).await;
let _ = sock.shutdown().await;
}
});

let client = Client::builder()
.timeout(Duration::from_secs(5))
.build()
.unwrap();
let out = post(&client, &url, &serde_json::json!({}), |b| b)
.await
.expect("post should succeed after retry");
assert_eq!(out, serde_json::json!({ "ok": true }));
assert!(
accepts.load(Ordering::SeqCst) >= 2,
"server should have seen at least 2 connection attempts, saw {}",
accepts.load(Ordering::SeqCst)
);
}
}
Loading