Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions codex-rs/app-server/src/codex_message_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,12 @@ impl CodexMessageProcessor {
Ok((review_request, hint))
}

pub async fn process_request(&mut self, connection_id: ConnectionId, request: ClientRequest) {
pub async fn process_request(
&mut self,
connection_id: ConnectionId,
request: ClientRequest,
app_server_client_name: Option<String>,
) {
let to_connection_request_id = |request_id| ConnectionRequestId {
connection_id,
request_id,
Expand Down Expand Up @@ -647,8 +652,12 @@ impl CodexMessageProcessor {
.await;
}
ClientRequest::TurnStart { request_id, params } => {
self.turn_start(to_connection_request_id(request_id), params)
.await;
self.turn_start(
to_connection_request_id(request_id),
params,
app_server_client_name.clone(),
)
.await;
}
ClientRequest::TurnSteer { request_id, params } => {
self.turn_steer(to_connection_request_id(request_id), params)
Expand Down Expand Up @@ -767,12 +776,20 @@ impl CodexMessageProcessor {
.await;
}
ClientRequest::SendUserMessage { request_id, params } => {
self.send_user_message(to_connection_request_id(request_id), params)
.await;
self.send_user_message(
to_connection_request_id(request_id),
params,
app_server_client_name.clone(),
)
.await;
}
ClientRequest::SendUserTurn { request_id, params } => {
self.send_user_turn(to_connection_request_id(request_id), params)
.await;
self.send_user_turn(
to_connection_request_id(request_id),
params,
app_server_client_name.clone(),
)
.await;
}
ClientRequest::InterruptConversation { request_id, params } => {
self.interrupt_conversation(to_connection_request_id(request_id), params)
Expand Down Expand Up @@ -5062,6 +5079,7 @@ impl CodexMessageProcessor {
&self,
request_id: ConnectionRequestId,
params: SendUserMessageParams,
app_server_client_name: Option<String>,
) {
let SendUserMessageParams {
conversation_id,
Expand All @@ -5080,6 +5098,12 @@ impl CodexMessageProcessor {
self.outgoing.send_error(request_id, error).await;
return;
};
if let Err(error) =
Self::set_app_server_client_name(conversation.as_ref(), app_server_client_name).await
{
self.outgoing.send_error(request_id, error).await;
return;
}

let mapped_items: Vec<CoreInputItem> = items
.into_iter()
Expand Down Expand Up @@ -5110,7 +5134,12 @@ impl CodexMessageProcessor {
.await;
}

async fn send_user_turn(&self, request_id: ConnectionRequestId, params: SendUserTurnParams) {
async fn send_user_turn(
&self,
request_id: ConnectionRequestId,
params: SendUserTurnParams,
app_server_client_name: Option<String>,
) {
let SendUserTurnParams {
conversation_id,
items,
Expand All @@ -5136,6 +5165,12 @@ impl CodexMessageProcessor {
self.outgoing.send_error(request_id, error).await;
return;
};
if let Err(error) =
Self::set_app_server_client_name(conversation.as_ref(), app_server_client_name).await
{
self.outgoing.send_error(request_id, error).await;
return;
}

let mapped_items: Vec<CoreInputItem> = items
.into_iter()
Expand Down Expand Up @@ -5607,7 +5642,12 @@ impl CodexMessageProcessor {
let _ = conversation.submit(Op::Interrupt).await;
}

async fn turn_start(&self, request_id: ConnectionRequestId, params: TurnStartParams) {
async fn turn_start(
&self,
request_id: ConnectionRequestId,
params: TurnStartParams,
app_server_client_name: Option<String>,
) {
if let Err(error) = Self::validate_v2_input_limit(&params.input) {
self.outgoing.send_error(request_id, error).await;
return;
Expand All @@ -5619,6 +5659,12 @@ impl CodexMessageProcessor {
return;
}
};
if let Err(error) =
Self::set_app_server_client_name(thread.as_ref(), app_server_client_name).await
{
self.outgoing.send_error(request_id, error).await;
return;
}

let collaboration_modes_config = CollaborationModesConfig {
default_mode_request_user_input: thread.enabled(Feature::DefaultModeRequestUserInput),
Expand Down Expand Up @@ -5700,6 +5746,20 @@ impl CodexMessageProcessor {
}
}

async fn set_app_server_client_name(
thread: &CodexThread,
app_server_client_name: Option<String>,
) -> Result<(), JSONRPCErrorError> {
thread
.set_app_server_client_name(app_server_client_name)
.await
.map_err(|err| JSONRPCErrorError {
code: INTERNAL_ERROR_CODE,
message: format!("failed to set app server client name: {err}"),
data: None,
})
}

async fn turn_steer(&self, request_id: ConnectionRequestId, params: TurnSteerParams) {
let (_, thread) = match self.load_thread(&params.thread_id).await {
Ok(v) => v,
Expand Down
4 changes: 3 additions & 1 deletion codex-rs/app-server/src/message_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ pub(crate) struct ConnectionSessionState {
pub(crate) initialized: bool,
pub(crate) experimental_api_enabled: bool,
pub(crate) opted_out_notification_methods: HashSet<String>,
pub(crate) app_server_client_name: Option<String>,
}

pub(crate) struct MessageProcessorArgs {
Expand Down Expand Up @@ -329,6 +330,7 @@ impl MessageProcessor {
if let Ok(mut suffix) = USER_AGENT_SUFFIX.lock() {
*suffix = Some(user_agent_suffix);
}
session.app_server_client_name = Some(name.clone());

let user_agent = get_codex_user_agent();
let response = InitializeResponse { user_agent };
Expand Down Expand Up @@ -430,7 +432,7 @@ impl MessageProcessor {
// inline the full `CodexMessageProcessor::process_request` future, which
// can otherwise push worker-thread stack usage over the edge.
self.codex_message_processor
.process_request(connection_id, other)
.process_request(connection_id, other, session.app_server_client_name.clone())
.boxed()
.await;
}
Expand Down
103 changes: 103 additions & 0 deletions codex-rs/app-server/tests/suite/v2/initialize.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::to_response;
use codex_app_server_protocol::ClientInfo;
use codex_app_server_protocol::InitializeCapabilities;
use codex_app_server_protocol::InitializeResponse;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::UserInput as V2UserInput;
use core_test_support::fs_wait;
use pretty_assertions::assert_eq;
use serde_json::Value;
use std::path::Path;
use std::time::Duration;
use tempfile::TempDir;
use tokio::time::timeout;

Expand Down Expand Up @@ -178,11 +186,100 @@ async fn initialize_opt_out_notification_methods_filters_notifications() -> Resu
Ok(())
}

#[tokio::test]
async fn turn_start_notify_payload_includes_initialize_client_name() -> Result<()> {
let responses = vec![create_final_assistant_message_sse_response("Done")?];
let server = create_mock_responses_server_sequence_unchecked(responses).await;
let codex_home = TempDir::new()?;
let notify_script = codex_home.path().join("notify.py");
std::fs::write(
&notify_script,
r#"from pathlib import Path
import sys

Path(__file__).with_name("notify.json").write_text(sys.argv[-1], encoding="utf-8")
"#,
)?;
let notify_file = codex_home.path().join("notify.json");
let notify_script = notify_script
.to_str()
.expect("notify script path should be valid UTF-8");
create_config_toml_with_extra(
codex_home.path(),
&server.uri(),
"never",
&format!(
"notify = [\"python3\", {}]",
toml_basic_string(notify_script)
),
)?;

let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(
DEFAULT_READ_TIMEOUT,
mcp.initialize_with_client_info(ClientInfo {
name: "xcode".to_string(),
title: Some("Xcode".to_string()),
version: "1.0.0".to_string(),
}),
)
.await??;

let thread_req = mcp
.send_thread_start_request(ThreadStartParams::default())
.await?;
let thread_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(thread_req)),
)
.await??;
let ThreadStartResponse { thread, .. } = to_response(thread_resp)?;

let turn_req = mcp
.send_turn_start_request(TurnStartParams {
thread_id: thread.id,
input: vec![V2UserInput::Text {
text: "Hello".to_string(),
text_elements: Vec::new(),
}],
..Default::default()
})
.await?;
let turn_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
)
.await??;
let _: TurnStartResponse = to_response(turn_resp)?;

timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("turn/completed"),
)
.await??;

fs_wait::wait_for_path_exists(&notify_file, Duration::from_secs(5)).await?;
let payload_raw = tokio::fs::read_to_string(&notify_file).await?;
let payload: Value = serde_json::from_str(&payload_raw)?;
assert_eq!(payload["client"], "xcode");

Ok(())
}

// Helper to create a config.toml pointing at the mock model server.
fn create_config_toml(
codex_home: &Path,
server_uri: &str,
approval_policy: &str,
) -> std::io::Result<()> {
create_config_toml_with_extra(codex_home, server_uri, approval_policy, "")
}

fn create_config_toml_with_extra(
codex_home: &Path,
server_uri: &str,
approval_policy: &str,
extra: &str,
) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
Expand All @@ -195,6 +292,8 @@ sandbox_mode = "read-only"

model_provider = "mock_provider"

{extra}

[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
Expand All @@ -205,3 +304,7 @@ stream_max_retries = 0
),
)
}

fn toml_basic_string(value: &str) -> String {
format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
}
Loading
Loading