diff --git a/Cargo.lock b/Cargo.lock index 95c5ae4dc..a8ea1d4a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3960,11 +3960,14 @@ dependencies = [ name = "sprout-dev-mcp" version = "0.1.0" dependencies = [ + "base64", "git-credential-nostr", "git-sign-nostr", "ignore", + "image", "nix", "nostr", + "reqwest 0.13.3", "rmcp", "schemars", "serde", diff --git a/crates/sprout-agent/src/agent.rs b/crates/sprout-agent/src/agent.rs index 46397431a..5d7237ad9 100644 --- a/crates/sprout-agent/src/agent.rs +++ b/crates/sprout-agent/src/agent.rs @@ -11,6 +11,7 @@ use crate::mcp::McpRegistry; use crate::types::{ AgentError, ContentBlock, HistoryItem, ProviderStop, StopReason, ToolCall, ToolResult, + ToolResultContent, }; use crate::wire::{self, WireSender}; @@ -214,13 +215,17 @@ impl RunCtx<'_> { for (i, call) in calls.iter().enumerate() { let mut result = results[i].take().unwrap_or_else(|| ToolResult { provider_id: call.provider_id.clone(), - text: "internal error: missing result".into(), + content: vec![ToolResultContent::Text( + "internal error: missing result".into(), + )], is_error: true, }); // On tool error: append a reflection prompt so the LLM // diagnoses the failure before blindly retrying. 
if result.is_error { - result.text.push_str(ERROR_REFLECTION_SUFFIX); + result + .content + .push(ToolResultContent::Text(ERROR_REFLECTION_SUFFIX.to_string())); } self.history.push(HistoryItem::ToolResult(result)); } @@ -445,7 +450,7 @@ async fn emit_completed(wire: &WireSender, sid: &str, call: &ToolCall, result: & "sessionUpdate": "tool_call_update", "toolCallId": call.provider_id, "status": "completed", - "content": [{ "type": "content", "content": { "type": "text", "text": result.text } }], + "content": [{ "type": "content", "content": { "type": "text", "text": result.text() } }], "rawOutput": { "isError": result.is_error }, }), ), @@ -537,7 +542,9 @@ pub(crate) fn push_hook_outputs_as_tool_results( }); history.push(HistoryItem::ToolResult(ToolResult { provider_id, - text: format_hook_output_body(hook, server, text), + content: vec![ToolResultContent::Text(format_hook_output_body( + hook, server, text, + ))], is_error: false, })); } @@ -555,7 +562,7 @@ fn unique_nonce() -> u64 { fn synthetic_tool_result(call: &ToolCall, msg: String) -> ToolResult { ToolResult { provider_id: call.provider_id.clone(), - text: msg, + content: vec![ToolResultContent::Text(msg)], is_error: true, } } diff --git a/crates/sprout-agent/src/config.rs b/crates/sprout-agent/src/config.rs index be72e0ad9..923b69e96 100644 --- a/crates/sprout-agent/src/config.rs +++ b/crates/sprout-agent/src/config.rs @@ -3,7 +3,7 @@ use std::time::Duration; pub const PROTOCOL_VERSION: u32 = 1; pub const MAX_PROMPT_BYTES: usize = 1024 * 1024; -pub const MAX_TOOL_RESULT_BYTES: usize = 256 * 1024; +pub const MAX_TOOL_RESULT_BYTES: usize = 8 * 1024 * 1024; pub const MAX_TOOL_CALLS_PER_TURN: usize = 64; /// Leaves headroom for the summary call. 
@@ -105,7 +105,7 @@ impl Config { mcp_restart_max_ms: parse_env("SPROUT_AGENT_MCP_RESTART_MAX_MS", 30_000u64)?, max_sessions: parse_env("SPROUT_AGENT_MAX_SESSIONS", usize::MAX)?, max_line_bytes: parse_env("SPROUT_AGENT_MAX_LINE_BYTES", 4 * 1024 * 1024)?, - max_history_bytes: parse_env("SPROUT_AGENT_MAX_HISTORY_BYTES", 1024 * 1024)?, + max_history_bytes: parse_env("SPROUT_AGENT_MAX_HISTORY_BYTES", 16 * 1024 * 1024)?, max_handoffs: parse_env("SPROUT_AGENT_MAX_HANDOFFS", 5)?, max_parallel_tools: parse_env("SPROUT_AGENT_MAX_PARALLEL_TOOLS", 8usize)?, hook_timeout: Duration::from_millis(parse_env( diff --git a/crates/sprout-agent/src/handoff.rs b/crates/sprout-agent/src/handoff.rs index 4cb47e701..771c3adda 100644 --- a/crates/sprout-agent/src/handoff.rs +++ b/crates/sprout-agent/src/handoff.rs @@ -180,7 +180,7 @@ fn push_history_snippet(out: &mut String, item: &HistoryItem) { } HistoryItem::ToolResult(r) => { out.push_str(if r.is_error { "[tool_err] " } else { "[tool] " }); - out.push_str(&clamp_for_snippet(&r.text)); + out.push_str(&clamp_for_snippet(&r.text())); out.push('\n'); } } diff --git a/crates/sprout-agent/src/llm.rs b/crates/sprout-agent/src/llm.rs index 9760dd049..ddf85f637 100644 --- a/crates/sprout-agent/src/llm.rs +++ b/crates/sprout-agent/src/llm.rs @@ -2,7 +2,9 @@ use reqwest::Client; use serde_json::{json, Value}; use crate::config::{Config, Provider}; -use crate::types::{AgentError, HistoryItem, LlmResponse, ProviderStop, ToolCall, ToolDef}; +use crate::types::{ + AgentError, HistoryItem, LlmResponse, ProviderStop, ToolCall, ToolDef, ToolResultContent, +}; const MAX_LLM_RESPONSE_BYTES: usize = 16 * 1024 * 1024; const MAX_LLM_ERROR_BODY_BYTES: usize = 4 * 1024; @@ -127,7 +129,7 @@ fn anthropic_body(cfg: &Config, history: &[HistoryItem], tools: &[ToolDef]) -> V } HistoryItem::ToolResult(r) => pending.push(json!({ "type": "tool_result", "tool_use_id": r.provider_id, - "content": [{ "type": "text", "text": r.text }], "is_error": r.is_error })), + 
"content": anthropic_tool_result_content(&r.content), "is_error": r.is_error })), } } flush(&mut messages, &mut pending); @@ -146,6 +148,19 @@ fn anthropic_body(cfg: &Config, history: &[HistoryItem], tools: &[ToolDef]) -> V body } +fn anthropic_tool_result_content(content: &[ToolResultContent]) -> Vec { + content + .iter() + .map(|c| match c { + ToolResultContent::Text(text) => json!({ "type": "text", "text": text }), + ToolResultContent::Image { data, mime_type } => json!({ + "type": "image", + "source": { "type": "base64", "media_type": mime_type, "data": data }, + }), + }) + .collect() +} + fn openai_body(cfg: &Config, history: &[HistoryItem], tools: &[ToolDef]) -> Value { let mut messages: Vec = vec![json!({ "role": "system", "content": cfg.system_prompt })]; for item in history { @@ -170,8 +185,15 @@ fn openai_body(cfg: &Config, history: &[HistoryItem], tools: &[ToolDef]) -> Valu } messages.push(Value::Object(msg)); } - HistoryItem::ToolResult(r) => messages.push(json!({ - "role": "tool", "tool_call_id": r.provider_id, "content": r.text })), + HistoryItem::ToolResult(r) => { + messages.push(json!({ + "role": "tool", "tool_call_id": r.provider_id, + "content": openai_tool_text_content(&r.content) })); + let image_content = openai_image_user_content(&r.content); + if !image_content.is_empty() { + messages.push(json!({ "role": "user", "content": image_content })); + } + } } } let tools_json: Vec = tools @@ -192,6 +214,33 @@ fn openai_body(cfg: &Config, history: &[HistoryItem], tools: &[ToolDef]) -> Valu body } +fn openai_tool_text_content(content: &[ToolResultContent]) -> String { + let mut parts = Vec::new(); + for c in content { + match c { + ToolResultContent::Text(text) => parts.push(text.clone()), + ToolResultContent::Image { data, mime_type } => parts.push(format!( + "This tool result included an image ({mime_type}, {} base64 bytes) that is provided in the next user message.", + data.len() + )), + } + } + parts.join("\n") +} + +fn 
openai_image_user_content(content: &[ToolResultContent]) -> Vec { + content + .iter() + .filter_map(|c| match c { + ToolResultContent::Image { data, mime_type } => Some(json!({ + "type": "image_url", + "image_url": { "url": format!("data:{mime_type};base64,{data}") }, + })), + ToolResultContent::Text(_) => None, + }) + .collect() +} + fn map_stop(s: Option<&str>) -> ProviderStop { match s { Some("end_turn" | "stop") => ProviderStop::EndTurn, @@ -389,3 +438,90 @@ where } Err(AgentError::Llm("exhausted retries".into())) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{Config, HookServers, Provider}; + use crate::types::{HistoryItem, ToolCall, ToolResult, ToolResultContent}; + use std::time::Duration; + + fn cfg(provider: Provider) -> Config { + Config { + provider, + system_prompt: "system".into(), + max_rounds: 10, + max_output_tokens: 1024, + llm_timeout: Duration::from_secs(10), + tool_timeout: Duration::from_secs(10), + mcp_init_timeout: Duration::from_secs(10), + mcp_max_restart_attempts: 1, + mcp_restart_base_ms: 1, + mcp_restart_max_ms: 1, + max_sessions: 1, + max_line_bytes: 1024 * 1024, + max_history_bytes: 16 * 1024 * 1024, + max_handoffs: 1, + max_parallel_tools: 1, + hook_timeout: Duration::from_secs(1), + stop_max_rejections: 0, + hook_servers: HookServers::None, + api_key: "key".into(), + model: "model".into(), + base_url: "http://example.invalid".into(), + anthropic_api_version: "2023-06-01".into(), + } + } + + fn image_history() -> Vec { + vec![ + HistoryItem::User("describe the image".into()), + HistoryItem::Assistant { + text: String::new(), + tool_calls: vec![ToolCall { + provider_id: "toolu_1".into(), + name: "dev__view_image".into(), + arguments: serde_json::json!({"source":"x.png"}), + }], + }, + HistoryItem::ToolResult(ToolResult { + provider_id: "toolu_1".into(), + content: vec![ + ToolResultContent::Text("10×10, 70 B (image/png from x.png)".into()), + ToolResultContent::Image { + data: "aW1n".into(), + mime_type: 
"image/png".into(), + }, + ], + is_error: false, + }), + ] + } + + #[test] + fn anthropic_tool_result_preserves_image_block() { + let body = anthropic_body(&cfg(Provider::Anthropic), &image_history(), &[]); + let content = &body["messages"][2]["content"][0]["content"]; + assert_eq!(content[0]["type"], "text"); + assert_eq!(content[1]["type"], "image"); + assert_eq!(content[1]["source"]["type"], "base64"); + assert_eq!(content[1]["source"]["media_type"], "image/png"); + assert_eq!(content[1]["source"]["data"], "aW1n"); + } + + #[test] + fn openai_tool_result_adds_followup_image_user_message() { + let body = openai_body(&cfg(Provider::OpenAi), &image_history(), &[]); + assert_eq!(body["messages"][3]["role"], "tool"); + assert!(body["messages"][3]["content"] + .as_str() + .unwrap() + .contains("provided in the next user message")); + assert_eq!(body["messages"][4]["role"], "user"); + assert_eq!(body["messages"][4]["content"][0]["type"], "image_url"); + assert_eq!( + body["messages"][4]["content"][0]["image_url"]["url"], + "data:image/png;base64,aW1n" + ); + } +} diff --git a/crates/sprout-agent/src/mcp.rs b/crates/sprout-agent/src/mcp.rs index ef875663a..b332400b3 100644 --- a/crates/sprout-agent/src/mcp.rs +++ b/crates/sprout-agent/src/mcp.rs @@ -14,7 +14,7 @@ use tokio::sync::watch; use tokio::sync::Mutex as AsyncMutex; use crate::config::{Config, HookServers}; -use crate::types::{clamp, AgentError, McpServerStdio, ToolDef, ToolResult}; +use crate::types::{clamp, AgentError, McpServerStdio, ToolDef, ToolResult, ToolResultContent}; const SEP: &str = "__"; const MAX_NAME_LEN: usize = 128; @@ -343,8 +343,8 @@ impl McpRegistry { if let Ok(mut counts) = self.hook_timeouts.lock() { counts.remove(&server_name); } - if !r.is_error && !r.text.trim().is_empty() { - indexed.push((idx, server_name, r.text)); + if !r.is_error && !r.text().trim().is_empty() { + indexed.push((idx, server_name, r.text())); } } Ok((_idx, server_name, Err(_elapsed))) => { @@ -594,15 +594,18 @@ impl 
McpRegistry { // Server is healthy — it correctly rejected bad input. Return to LLM. return Ok(ToolResult { provider_id: provider_id.to_owned(), - text: clamp(format!("Tool call rejected: {e}"), max_bytes), + content: vec![ToolResultContent::Text(clamp( + format!("Tool call rejected: {e}"), + max_bytes, + ))], is_error: true, }); } }; - let text = collapse_content(&res.content, max_bytes); + let content = tool_result_content(&res.content, max_bytes); Ok(ToolResult { provider_id: provider_id.to_owned(), - text: clamp(text, max_bytes), + content, is_error: res.is_error.unwrap_or(false), }) } @@ -837,46 +840,133 @@ fn push_bounded(out: &mut String, s: &str, max: usize) { } } -fn collapse_content(blocks: &[rmcp::model::Content], max_bytes: usize) -> String { +fn tool_result_content( + blocks: &[rmcp::model::Content], + max_bytes: usize, +) -> Vec { use rmcp::model::RawContent; - let mut out = String::new(); + let mut out = Vec::new(); + let mut text = String::new(); + let mut used = 0usize; let mut truncated = false; let short = |s: &str| truncate_at_boundary(s, MARKER_FIELD_MAX).to_owned(); + + let flush_text = |out: &mut Vec, text: &mut String, used: &mut usize| { + if !text.is_empty() { + *used = used.saturating_add(text.len()); + out.push(ToolResultContent::Text(std::mem::take(text))); + } + }; + + let text_budget = + |used: usize, text: &str| max_bytes.saturating_sub(used).saturating_sub(text.len()); + for c in blocks { - if out.len() >= max_bytes { + if used + text.len() >= max_bytes { truncated = true; break; } - if !out.is_empty() { - push_bounded(&mut out, "\n", max_bytes); - } - let chunk: String = match &c.raw { - RawContent::Text(t) => t.text.clone(), + match &c.raw { + RawContent::Text(t) => { + if !text.is_empty() { + let max = text_budget(used, &text); + push_bounded(&mut text, "\n", max); + } + let before = text.len(); + let max = text_budget(used, &text); + push_bounded(&mut text, &t.text, max); + if text.len() - before < t.text.len() { + truncated = 
true; + } + } RawContent::Image(i) => { - format!( - "[image elided: {}, {} bytes]", - short(&i.mime_type), - i.data.len() - ) + flush_text(&mut out, &mut text, &mut used); + let image_bytes = i.data.len().saturating_add(i.mime_type.len()); + if used.saturating_add(image_bytes) <= max_bytes { + used = used.saturating_add(image_bytes); + out.push(ToolResultContent::Image { + data: i.data.clone(), + mime_type: i.mime_type.clone(), + }); + } else { + truncated = true; + let marker = format!( + "[image elided: {}, {} base64 bytes exceeds remaining tool-result budget]", + short(&i.mime_type), + i.data.len() + ); + let max = max_bytes.saturating_sub(used); + push_bounded(&mut text, &marker, max); + } } RawContent::Audio(a) => { - format!( + if !text.is_empty() { + let max = text_budget(used, &text); + push_bounded(&mut text, "\n", max); + } + let chunk = format!( "[audio elided: {}, {} bytes]", short(&a.mime_type), a.data.len() - ) + ); + let max = text_budget(used, &text); + push_bounded(&mut text, &chunk, max); + } + RawContent::ResourceLink(r) => { + if !text.is_empty() { + let max = text_budget(used, &text); + push_bounded(&mut text, "\n", max); + } + let chunk = format!("[resource: {}]", short(&r.uri)); + let max = text_budget(used, &text); + push_bounded(&mut text, &chunk, max); + } + RawContent::Resource(_) => { + if !text.is_empty() { + let max = text_budget(used, &text); + push_bounded(&mut text, "\n", max); + } + let max = text_budget(used, &text); + push_bounded(&mut text, "[resource elided]", max); } - RawContent::ResourceLink(r) => format!("[resource: {}]", short(&r.uri)), - RawContent::Resource(_) => "[resource elided]".into(), - }; - let before = out.len(); - push_bounded(&mut out, &chunk, max_bytes); - if out.len() - before < chunk.len() { - truncated = true; } } if truncated { - out.push_str("\n[content truncated]"); + let max = text_budget(used, &text); + push_bounded(&mut text, "\n[content truncated]", max); } + flush_text(&mut out, &mut text, &mut 
used); out } + +#[cfg(test)] +mod content_tests { + use super::*; + use rmcp::model::Content; + + #[test] + fn tool_result_content_preserves_images() { + let blocks = vec![ + Content::text("header"), + Content::image("aW1n", "image/png"), + Content::text("tail"), + ]; + let out = tool_result_content(&blocks, 1024); + assert_eq!(out.len(), 3); + assert!(matches!(&out[0], ToolResultContent::Text(t) if t == "header")); + assert!(matches!( + &out[1], + ToolResultContent::Image { data, mime_type } + if data == "aW1n" && mime_type == "image/png" + )); + assert!(matches!(&out[2], ToolResultContent::Text(t) if t == "tail")); + } + + #[test] + fn tool_result_content_elides_images_over_budget() { + let blocks = vec![Content::image("a".repeat(300), "image/png")]; + let out = tool_result_content(&blocks, 256); + assert_eq!(out.len(), 1); + assert!(matches!(&out[0], ToolResultContent::Text(t) if t.contains("image elided"))); + } +} diff --git a/crates/sprout-agent/src/types.rs b/crates/sprout-agent/src/types.rs index d05ec51cc..a0905b444 100644 --- a/crates/sprout-agent/src/types.rs +++ b/crates/sprout-agent/src/types.rs @@ -1,6 +1,33 @@ use serde::Deserialize; use serde_json::Value; +#[derive(Debug, Clone)] +pub enum ToolResultContent { + Text(String), + Image { data: String, mime_type: String }, +} + +impl ToolResultContent { + pub fn estimated_bytes(&self) -> usize { + match self { + Self::Text(s) => s.len(), + // This is request-size pressure accounting, not a visual-token + // estimate. Count the base64 bytes we will actually serialize so + // image-heavy sessions cannot silently exceed provider/body caps. 
+ Self::Image { data, mime_type } => data.len() + mime_type.len(), + } + } + + pub fn as_text_lossy(&self) -> String { + match self { + Self::Text(s) => s.clone(), + Self::Image { data, mime_type } => { + format!("[image: {mime_type}, {} base64 bytes]", data.len()) + } + } + } +} + #[derive(Debug, Clone)] pub enum HistoryItem { User(String), @@ -28,7 +55,13 @@ impl HistoryItem { }) .sum::() } - Self::ToolResult(r) => r.provider_id.len() + r.text.len(), + Self::ToolResult(r) => { + r.provider_id.len() + + r.content + .iter() + .map(ToolResultContent::estimated_bytes) + .sum::() + } } } } @@ -43,10 +76,20 @@ pub struct ToolCall { #[derive(Debug, Clone)] pub struct ToolResult { pub provider_id: String, - pub text: String, + pub content: Vec, pub is_error: bool, } +impl ToolResult { + pub fn text(&self) -> String { + self.content + .iter() + .map(ToolResultContent::as_text_lossy) + .collect::>() + .join("\n") + } +} + #[derive(Debug, Clone)] pub struct LlmResponse { pub text: String, diff --git a/crates/sprout-dev-mcp/Cargo.toml b/crates/sprout-dev-mcp/Cargo.toml index 162d228a9..574ede553 100644 --- a/crates/sprout-dev-mcp/Cargo.toml +++ b/crates/sprout-dev-mcp/Cargo.toml @@ -26,6 +26,11 @@ tempfile = "3" ignore = "0.4.25" tracing = { workspace = true } tracing-subscriber = { workspace = true } +# view_image tool: HTTP fetch (workspace reqwest is already used by sprout-cli; +# adding it here is essentially free), base64 encoding, and decode/resize. 
+reqwest = { workspace = true } +base64 = "0.22" +image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } [target.'cfg(unix)'.dependencies] nix = { version = "0.31", default-features = false, features = ["signal", "process"] } diff --git a/crates/sprout-dev-mcp/src/main.rs b/crates/sprout-dev-mcp/src/main.rs index 617ca279d..b163f5a25 100644 --- a/crates/sprout-dev-mcp/src/main.rs +++ b/crates/sprout-dev-mcp/src/main.rs @@ -9,12 +9,14 @@ use rmcp::{ use std::path::Path; use std::sync::Arc; +mod paths; mod rg; mod shell; mod shim; mod str_replace; mod todo; mod tree; +mod view_image; #[derive(Clone)] struct DevMcp { @@ -45,6 +47,17 @@ impl DevMcp { shell::run(&self.state, p, context.ct).await } + #[tool( + name = "view_image", + description = "Load an image from a file path, http(s) URL, or data: URL and return it as an MCP image content block that multimodal LLMs (Anthropic, OpenAI-compatible, etc.) can see. Resizes to a longest-edge of 1568px by default (override with `max_dim`, range 64..=2048). Pass-through for already-small PNG/JPEG; transcodes oversize input to PNG (if alpha) or JPEG q85. Animated GIF/WebP rejected — provide a still frame. Hard cap 20 MiB source, ~4 MiB on the wire. Relative paths resolve under `workdir` (defaults to server cwd) and may not escape it." + )] + async fn view_image( + &self, + Parameters(p): Parameters, + ) -> Result { + view_image::run(&self.state, p).await + } + #[tool( name = "str_replace", description = "Atomic find-and-replace in a file. old_str must occur exactly once. Returns a unified diff. Path resolved relative to workdir (defaults to server cwd). Prefer over sed/awk." diff --git a/crates/sprout-dev-mcp/src/paths.rs b/crates/sprout-dev-mcp/src/paths.rs new file mode 100644 index 000000000..3fb605097 --- /dev/null +++ b/crates/sprout-dev-mcp/src/paths.rs @@ -0,0 +1,64 @@ +//! Path resolution shared across dev-mcp tools. +//! +//! 
`resolve_within` canonicalises a user-supplied path against a workspace +//! root and rejects any result that escapes the root (e.g. via `..`, absolute +//! paths, or symlinks). All tools that touch the filesystem must funnel +//! through this helper so the escape policy stays consistent. + +use std::path::{Path, PathBuf}; + +/// Resolve `path` (absolute or relative) against `root` and require the +/// canonicalised result to live under the canonicalised `root`. Returns an +/// error string suitable for `ErrorData::invalid_params` on rejection. +pub(crate) fn resolve_within(root: &Path, path: &str) -> Result { + let raw = Path::new(path); + let candidate: PathBuf = if raw.is_absolute() { + raw.to_path_buf() + } else { + root.join(raw) + }; + + let root_canon = std::fs::canonicalize(root) + .map_err(|e| format!("workdir not accessible: {} ({e})", root.display()))?; + + let resolved = std::fs::canonicalize(&candidate) + .map_err(|e| format!("path not accessible: {} ({e})", candidate.display()))?; + + if !resolved.starts_with(&root_canon) { + return Err(format!( + "path escapes workspace: {} not within {}", + resolved.display(), + root_canon.display() + )); + } + Ok(resolved) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::tempdir; + + #[test] + fn resolve_within_rejects_escape() { + let dir = tempdir().expect("tempdir"); + let inside = dir.path().join("file.txt"); + fs::write(&inside, b"x").expect("write"); + // Symlink targeting outside the dir should be rejected. 
+ #[cfg(unix)] + { + let outside = std::env::temp_dir().join("sprout-mcp-paths-escape-target"); + let _ = fs::remove_file(&outside); + fs::write(&outside, b"y").expect("write outside"); + let link = dir.path().join("link.txt"); + std::os::unix::fs::symlink(&outside, &link).expect("symlink"); + let err = resolve_within(dir.path(), "link.txt").unwrap_err(); + assert!(err.contains("escapes workspace"), "got: {err}"); + let _ = fs::remove_file(&outside); + } + // Resolves a normal path inside. + let p = resolve_within(dir.path(), "file.txt").expect("resolve"); + assert!(p.ends_with("file.txt")); + } +} diff --git a/crates/sprout-dev-mcp/src/str_replace.rs b/crates/sprout-dev-mcp/src/str_replace.rs index 0a788bd0e..0251c97dc 100644 --- a/crates/sprout-dev-mcp/src/str_replace.rs +++ b/crates/sprout-dev-mcp/src/str_replace.rs @@ -154,29 +154,7 @@ pub fn run(state: &SharedState, p: StrReplaceParams) -> Result Result { - let raw = Path::new(path); - let candidate: PathBuf = if raw.is_absolute() { - raw.to_path_buf() - } else { - root.join(raw) - }; - - let root_canon = std::fs::canonicalize(root) - .map_err(|e| format!("workdir not accessible: {} ({e})", root.display()))?; - - let resolved = std::fs::canonicalize(&candidate) - .map_err(|e| format!("path not accessible: {} ({e})", candidate.display()))?; - - if !resolved.starts_with(&root_canon) { - return Err(format!( - "path escapes workspace: {} not within {}", - resolved.display(), - root_canon.display() - )); - } - Ok(resolved) -} +pub(crate) use crate::paths::resolve_within; pub(crate) fn count_occurrences_capped(text: &str, pattern: &str) -> usize { if pattern.is_empty() { @@ -301,28 +279,6 @@ mod tests { assert_eq!(count_occurrences_capped("abc", ""), 0); } - #[test] - fn resolve_within_rejects_escape() { - let dir = tempdir().expect("tempdir"); - let inside = dir.path().join("file.txt"); - fs::write(&inside, b"x").expect("write"); - // Symlink targeting outside the dir should be rejected. 
- #[cfg(unix)] - { - let outside = std::env::temp_dir().join("sprout-mcp-escape-target"); - let _ = fs::remove_file(&outside); - fs::write(&outside, b"y").expect("write outside"); - let link = dir.path().join("link.txt"); - std::os::unix::fs::symlink(&outside, &link).expect("symlink"); - let err = resolve_within(dir.path(), "link.txt").unwrap_err(); - assert!(err.contains("escapes workspace"), "got: {err}"); - let _ = fs::remove_file(&outside); - } - // Resolves a normal path inside. - let p = resolve_within(dir.path(), "file.txt").expect("resolve"); - assert!(p.ends_with("file.txt")); - } - fn make_state(cwd: &std::path::Path) -> SharedState { let shim = crate::shim::Shim::install().expect("shim install"); SharedState::new(cwd.to_path_buf(), shim).expect("state new") diff --git a/crates/sprout-dev-mcp/src/view_image.rs b/crates/sprout-dev-mcp/src/view_image.rs new file mode 100644 index 000000000..5221ca942 --- /dev/null +++ b/crates/sprout-dev-mcp/src/view_image.rs @@ -0,0 +1,934 @@ +//! `view_image` MCP tool — load an image from a path, http(s) URL, or +//! `data:` URL and return it as an MCP `image` content block that any +//! multimodal-capable host (Anthropic, OpenAI-compatible, etc.) can forward +//! to its model. +//! +//! Design goals: tiny surface, no protocol-specific branching, and a +//! "reasonable resolution" that fits comfortably inside both Anthropic's +//! recommended ≤1568px / ≤5 MiB image budget and OpenAI's high-detail tile +//! size sweet spot. The MCP host translates `Content::image(data, mime)` +//! into the right provider-native shape on our behalf (see Goose's +//! `providers::utils::convert_image` for a reference implementation). 
+ +use crate::paths::resolve_within; +use crate::shell::SharedState; +use base64::Engine; +use image::{ + codecs::{jpeg::JpegEncoder, png::PngEncoder}, + DynamicImage, ExtendedColorType, ImageEncoder, ImageReader, Limits, +}; +use rmcp::{ + model::{CallToolResult, Content}, + ErrorData, +}; +use schemars::JsonSchema; +use serde::Deserialize; +use std::io::Cursor; +use std::path::PathBuf; +use std::time::Duration; + +/// Hard cap on bytes we will read from disk / URL / data: URL. +pub(crate) const MAX_SOURCE_BYTES: usize = 20 * 1024 * 1024; +/// Hard cap on the raw (pre-base64) bytes we emit. base64 expands by 4/3, so +/// a 3 MiB raw payload becomes ~4 MiB on the wire — comfortably below +/// Anthropic's 5 MiB-per-image limit. +pub(crate) const MAX_FINAL_RAW_BYTES: usize = 3 * 1024 * 1024; +/// Default longest-edge cap. Matches Anthropic's published recommendation +/// (≤1568px) and lands well inside OpenAI's high-detail tile budget. +pub(crate) const DEFAULT_MAX_DIM: u32 = 1568; +pub(crate) const MIN_MAX_DIM: u32 = 64; +pub(crate) const MAX_MAX_DIM: u32 = 2048; +/// Hard cap on decoded pixel count. A ≤20 MiB compressed source can decode +/// to hundreds of megabytes; we reject anything above this budget *before* +/// touching the decoder. 64 megapixels is generous (e.g. 8000×8000) yet +/// keeps worst-case allocation well under a gigabyte. +pub(crate) const MAX_PIXELS: u64 = 64 * 1024 * 1024; +/// Defence-in-depth for the `image` decoder: bound any single allocation it +/// performs to 256 MiB (the default is 512 MiB and skews high for a dev MCP). +pub(crate) const MAX_DECODER_ALLOC: u64 = 256 * 1024 * 1024; +/// Connect + read timeout for URL fetches. +const FETCH_TIMEOUT: Duration = Duration::from_secs(10); + +/// Build the decoder allocation cap. Centralised so the resize path uses the +/// same value tests can reason about. 
+fn decode_limits() -> Limits { + let mut l = Limits::default(); + l.max_alloc = Some(MAX_DECODER_ALLOC); + l +} + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct ViewImageParams { + /// Image source: an absolute or workspace-relative file path, + /// an `http://` / `https://` URL, or a `data:image/;base64,...` URL. + pub source: String, + /// Optional longest-edge cap in pixels. Clamped to [64, 2048]. + /// Defaults to 1568, which fits Anthropic's recommended budget and + /// OpenAI's high-detail tile size. + #[serde(default)] + pub max_dim: Option, + /// Workspace root for relative path resolution. Ignored for URL sources. + /// Defaults to the server's cwd. + #[serde(default)] + pub workdir: Option, +} + +/// What `view_image` returns: the (mime, raw bytes) we will base64-encode +/// into an MCP image content block, plus a short human-readable summary. +#[derive(Debug)] +struct PreparedImage { + mime: &'static str, + bytes: Vec, + summary: String, +} + +pub async fn run(state: &SharedState, p: ViewImageParams) -> Result { + let max_dim = p + .max_dim + .unwrap_or(DEFAULT_MAX_DIM) + .clamp(MIN_MAX_DIM, MAX_MAX_DIM); + + let (raw, source_label) = load_source(state, &p).await?; + let prepared = prepare(&raw, max_dim).map_err(invalid_params)?; + + let encoded = base64::engine::general_purpose::STANDARD.encode(&prepared.bytes); + let header = format!( + "{} ({} from {source_label})", + prepared.summary, prepared.mime + ); + + Ok(CallToolResult::success(vec![ + Content::text(header), + Content::image(encoded, prepared.mime.to_string()), + ])) +} + +fn invalid_params(msg: String) -> ErrorData { + ErrorData::invalid_params(msg, None) +} + +/// Fetch the source bytes from path / http(s) / data URL. 
+async fn load_source( + state: &SharedState, + p: &ViewImageParams, +) -> Result<(Vec, String), ErrorData> { + let src = p.source.trim(); + if src.starts_with("data:") { + let bytes = decode_data_url(src).map_err(invalid_params)?; + // `decode_data_url` enforces an encoded-length precheck so we never + // allocate past the source cap. Re-verify the decoded length for + // belt-and-braces. + if bytes.len() > MAX_SOURCE_BYTES { + return Err(invalid_params(format!( + "data: URL decoded to {} bytes (limit {} bytes)", + bytes.len(), + MAX_SOURCE_BYTES + ))); + } + Ok((bytes, "data:URL".to_string())) + } else if src.starts_with("http://") || src.starts_with("https://") { + let bytes = fetch_url(src).await?; + Ok((bytes, src.to_string())) + } else if src.contains("://") { + // Treat any other `scheme://...` form as an explicit reject so + // `ftp://...` doesn't accidentally become a filesystem path. + Err(invalid_params(format!( + "unsupported URL scheme in `source`: {src}", + ))) + } else { + let workspace_root = match p.workdir.as_deref() { + Some(w) => PathBuf::from(w), + None => state.cwd.clone(), + }; + let target = resolve_within(&workspace_root, src).map_err(invalid_params)?; + let meta = std::fs::metadata(&target).map_err(|e| { + ErrorData::internal_error(format!("cannot stat {}: {e}", target.display()), None) + })?; + if !meta.is_file() { + return Err(invalid_params(format!( + "not a regular file: {}", + target.display() + ))); + } + if meta.len() as usize > MAX_SOURCE_BYTES { + return Err(invalid_params(format!( + "file too large: {} is {} bytes (limit {} bytes)", + target.display(), + meta.len(), + MAX_SOURCE_BYTES + ))); + } + // Use `take(cap + 1)` so a file that grows between the metadata + // check and the read still cannot exceed our budget. The +1 + // distinguishes "exactly at cap" from "grew past cap". 
+ let file = std::fs::File::open(&target).map_err(|e| { + ErrorData::internal_error(format!("cannot open {}: {e}", target.display()), None) + })?; + let mut bytes = Vec::with_capacity(meta.len() as usize); + use std::io::Read; + file.take(MAX_SOURCE_BYTES as u64 + 1) + .read_to_end(&mut bytes) + .map_err(|e| { + ErrorData::internal_error(format!("cannot read {}: {e}", target.display()), None) + })?; + if bytes.len() > MAX_SOURCE_BYTES { + return Err(invalid_params(format!( + "file {} grew past {} byte cap during read", + target.display(), + MAX_SOURCE_BYTES + ))); + } + Ok((bytes, target.display().to_string())) + } +} + +/// Parse `data:image/[;base64],`. Only base64 payloads are +/// accepted — percent-encoded data URLs add surface area for no real benefit. +fn decode_data_url(src: &str) -> Result, String> { + let rest = src + .strip_prefix("data:") + .ok_or_else(|| "not a data: URL".to_string())?; + let (meta, payload) = rest + .split_once(',') + .ok_or_else(|| "malformed data: URL (no comma)".to_string())?; + // meta is "[;param=value]*[;base64]" + let mut parts = meta.split(';'); + let mime = parts.next().unwrap_or(""); + if !mime.starts_with("image/") { + return Err(format!("data: URL is not an image (got `{mime}`)")); + } + let is_base64 = parts.any(|p| p.eq_ignore_ascii_case("base64")); + if !is_base64 { + return Err( + "data: URL must be base64-encoded (non-base64 / percent-encoded forms are not supported)" + .to_string(), + ); + } + let payload = payload.trim(); + // Pre-check encoded length so we never allocate past the source cap. + // 4 base64 chars encode 3 raw bytes; ceil-divide MAX_SOURCE_BYTES. 
+ let max_encoded = MAX_SOURCE_BYTES.div_ceil(3) * 4 + 4; // +4 absorbs padding rounding + if payload.len() > max_encoded { + return Err(format!( + "data: URL payload is {} base64 chars (limit ~{} = {} raw bytes)", + payload.len(), + max_encoded, + MAX_SOURCE_BYTES + )); + } + base64::engine::general_purpose::STANDARD + .decode(payload) + .map_err(|e| format!("data: URL base64 decode failed: {e}")) +} + +/// Fetch an http(s) URL with a streaming read and a hard byte cap. +/// Refuses up-front if `Content-Length` advertises more than the cap. +async fn fetch_url(url: &str) -> Result, ErrorData> { + let client = reqwest::Client::builder() + .connect_timeout(FETCH_TIMEOUT) + .timeout(FETCH_TIMEOUT) + .build() + .map_err(|e| ErrorData::internal_error(format!("http client init failed: {e}"), None))?; + let resp = client + .get(url) + .send() + .await + .map_err(|e| ErrorData::internal_error(format!("fetch failed: {url} ({e})"), None))?; + if !resp.status().is_success() { + return Err(invalid_params(format!( + "fetch {url} returned HTTP {}", + resp.status() + ))); + } + if let Some(len) = resp.content_length() { + if len as usize > MAX_SOURCE_BYTES { + return Err(invalid_params(format!( + "remote image too large: Content-Length {} bytes (limit {})", + len, MAX_SOURCE_BYTES + ))); + } + } + let mut buf: Vec = Vec::new(); + let mut stream = resp; + loop { + let chunk = stream + .chunk() + .await + .map_err(|e| ErrorData::internal_error(format!("fetch read failed: {e}"), None))?; + match chunk { + Some(bytes) => { + if buf.len() + bytes.len() > MAX_SOURCE_BYTES { + return Err(invalid_params(format!( + "remote image exceeded {} byte cap mid-stream", + MAX_SOURCE_BYTES + ))); + } + buf.extend_from_slice(&bytes); + } + None => break, + } + } + Ok(buf) +} + +/// Sniff the image format from magic bytes alone (do not trust extensions +/// or `Content-Type`). Returns the canonical MIME type. 
fn sniff_mime(bytes: &[u8]) -> Result<&'static str, String> {
    // Leading signatures for the formats identified purely by a prefix.
    const PNG_MAGIC: &[u8] = b"\x89PNG\r\n\x1a\n"; // 89 50 4E 47 0D 0A 1A 0A
    const JPEG_MAGIC: &[u8] = &[0xFF, 0xD8, 0xFF];

    if bytes.starts_with(PNG_MAGIC) {
        Ok("image/png")
    } else if bytes.starts_with(JPEG_MAGIC) {
        Ok("image/jpeg")
    } else if matches!(bytes.get(..6), Some(b"GIF87a" | b"GIF89a")) {
        Ok("image/gif")
    } else if bytes.starts_with(b"RIFF") && bytes.get(8..12) == Some(&b"WEBP"[..]) {
        // WebP is a RIFF container: "RIFF", a 4-byte size, then "WEBP".
        Ok("image/webp")
    } else {
        Err("unsupported image format (recognised: png, jpeg, gif, webp)".to_string())
    }
}

/// Report whether `bytes` is an animated image of the given sniffed `mime`.
///
/// A GIF counts as animated when it holds ≥2 image descriptors; a WebP counts
/// as animated when its VP8X chunk has the ANIM bit set. Any other MIME type
/// is reported as not animated. Animated input is refused outright rather
/// than silently reduced to a first-frame still.
///
/// **Important**: both branches must stay allocation-free byte-level scans.
/// Decoding frames with the `image` crate's `GifDecoder::into_frames()` here
/// would let an attacker-controlled logical-screen size trigger a multi-GB
/// RGBA buffer before our pixel-count cap fires.
fn is_animated(bytes: &[u8], mime: &str) -> bool {
    if mime == "image/gif" {
        return gif_has_two_image_descriptors(bytes);
    }
    if mime != "image/webp" {
        return false;
    }
    // Animated WebP files always use the extended (VP8X) container, whose
    // chunk tag sits right after the 12-byte RIFF/WEBP header.
    let is_extended = bytes.get(12..16) == Some(&b"VP8X"[..]);
    // The animation flag is bit 1 of the flags byte at offset 20; a buffer
    // too short to contain that byte is treated as not animated.
    match bytes.get(20) {
        Some(flags) if is_extended => flags & 0x02 != 0,
        _ => false,
    }
}

/// Scan a GIF byte stream and report whether it contains ≥2 image descriptors
/// (frames). Does not allocate decode buffers — walks the block structure
/// described in the GIF89a spec and bails on the second `0x2C` separator.
fn gif_has_two_image_descriptors(bytes: &[u8]) -> bool {
    // 6-byte header ("GIF87a"/"GIF89a") + 7-byte logical screen descriptor.
+ if bytes.len() < 13 { + return false; + } + let packed = bytes[10]; + let has_gct = (packed & 0x80) != 0; + let gct_size = if has_gct { + 3 * (1u32 << ((packed & 0x07) + 1)) + } else { + 0 + }; + let mut i = 13usize + gct_size as usize; + let mut frames = 0u32; + while let Some(&b) = bytes.get(i) { + i += 1; + match b { + 0x3B => return frames >= 2, // trailer + 0x21 => { + // Extension introducer: