diff --git a/crates/navigator-cli/src/run.rs b/crates/navigator-cli/src/run.rs index 72ef6277..5850e998 100644 --- a/crates/navigator-cli/src/run.rs +++ b/crates/navigator-cli/src/run.rs @@ -1983,7 +1983,7 @@ fn inferred_provider_type(command: &[String]) -> Option { /// /// Returns a deduplicated list of provider **names** suitable for /// `SandboxSpec.providers`. -async fn ensure_required_providers( +pub async fn ensure_required_providers( client: &mut NavigatorClient, explicit_names: &[String], inferred_types: &[String], @@ -1996,20 +1996,14 @@ async fn ensure_required_providers( let mut configured_names: Vec = Vec::new(); let mut seen_names: HashSet = HashSet::new(); - // ── Explicit provider names (validated server-side) ─────────────────── - for name in explicit_names { - if seen_names.insert(name.clone()) { - configured_names.push(name.clone()); - } - } - - // ── Resolve inferred provider types ─────────────────────────────────── - if !inferred_types.is_empty() { - // Map from lowercase type -> first provider name found with that type. - let mut type_to_name: HashMap = HashMap::new(); + // ── Fetch all existing providers ───────────────────────────────────── + // Build both a name set (for explicit --provider lookups) and a + // type-to-name map (for inferred provider resolution). + let mut known_names: HashSet = HashSet::new(); + let mut type_to_name: HashMap = HashMap::new(); + { let mut offset = 0_u32; let limit = 100_u32; - loop { let response = client .list_providers(ListProvidersRequest { limit, offset }) @@ -2017,6 +2011,7 @@ async fn ensure_required_providers( .into_diagnostic()?; let providers = response.into_inner().providers; for provider in &providers { + known_names.insert(provider.name.clone()); if !provider.r#type.is_empty() { let type_lower = provider.r#type.to_ascii_lowercase(); type_to_name @@ -2024,13 +2019,47 @@ async fn ensure_required_providers( .or_insert_with(|| provider.name.clone()); } } - if providers.len() < limit as usize { break; } offset = offset.saturating_add(limit); } + } + + // ── Explicit provider names ────────────────────────────────────────── + // If the name exists on the server, use it directly. Otherwise, if the + // name matches a known provider type, auto-create a provider of that + // type with the requested name. + for name in explicit_names { + if known_names.contains(name) { + if seen_names.insert(name.clone()) { + configured_names.push(name.clone()); + } + } else if let Some(provider_type) = normalize_provider_type(name) { + auto_create_provider( + client, + provider_type, + Some(name), + auto_providers_override, + &mut seen_names, + &mut configured_names, + ) + .await?; + // Record the type mapping so the inferred-types pass below + // doesn't attempt to create a duplicate provider. + type_to_name + .entry(provider_type.to_ascii_lowercase()) + .or_insert_with(|| name.clone()); + } else { + return Err(miette::miette!( + "provider '{name}' not found and '{name}' is not a recognized provider type. \ + Create it first with `nemoclaw provider create --type --name {name}`" + )); + } + } + // ── Resolve inferred provider types ────────────────────────────────── + if !inferred_types.is_empty() { // Collect resolved names for types that already have a provider. for t in inferred_types { if let Some(name) = type_to_name.get(&t.to_ascii_lowercase()) @@ -2046,119 +2075,172 @@ async fn ensure_required_providers( .cloned() .collect::>(); - if !missing.is_empty() { - // --no-auto-providers: skip all missing providers silently. - if auto_providers_override == Some(false) { - for provider_type in &missing { - eprintln!( - "{} Skipping provider '{provider_type}' (--no-auto-providers)", - "!".yellow(), - ); - } - return Ok(configured_names); - } + for provider_type in missing { + auto_create_provider( + client, + &provider_type, + None, + auto_providers_override, + &mut seen_names, + &mut configured_names, + ) + .await?; + } + } - // No override and non-interactive: error. - if auto_providers_override.is_none() && !std::io::stdin().is_terminal() { - return Err(miette::miette!( - "missing required providers: {}. Create them first with \ - `nemoclaw provider create --type --name --from-existing`, \ - pass --auto-providers to auto-create, or set them up manually from inside the sandbox", - missing.join(", ") - )); - } + Ok(configured_names) +} - let registry = ProviderRegistry::new(); - for provider_type in missing { - eprintln!("Missing provider: {provider_type}"); +/// Prompt for (or auto-confirm) creation of a provider from local credentials. +/// +/// When `preferred_name` is `Some`, the provider is created with that exact +/// name (used for explicit `--provider ` values). When `None`, the name +/// defaults to the type and retries with suffixes on conflict (used for +/// inferred provider types). +async fn auto_create_provider( + client: &mut NavigatorClient, + provider_type: &str, + preferred_name: Option<&str>, + auto_providers_override: Option, + seen_names: &mut HashSet, + configured_names: &mut Vec, +) -> Result<()> { + eprintln!("Missing provider: {provider_type}"); - // --auto-providers: auto-confirm all. - let should_create = if auto_providers_override == Some(true) { - true - } else { - Confirm::new() - .with_prompt("Create from local credentials?") - .default(true) - .interact() - .into_diagnostic()? - }; + // --no-auto-providers: skip silently. + if auto_providers_override == Some(false) { + eprintln!( + "{} Skipping provider '{provider_type}' (--no-auto-providers)", + "!".yellow(), + ); + eprintln!(); + return Ok(()); + } - if !should_create { - eprintln!("{} Skipping provider '{provider_type}'", "!".yellow(),); - eprintln!(); - continue; - } + // No override and non-interactive: error. + if auto_providers_override.is_none() && !std::io::stdin().is_terminal() { + return Err(miette::miette!( + "missing required provider '{provider_type}'. Create it first with \ + `nemoclaw provider create --type {provider_type} --name {provider_type} --from-existing`, \ + pass --auto-providers to auto-create, or set it up manually from inside the sandbox" + )); + } - let discovered = registry.discover_existing(&provider_type).map_err(|err| { - miette::miette!("failed to discover provider '{provider_type}': {err}") - })?; - let Some(discovered) = discovered else { + // --auto-providers: auto-confirm; otherwise prompt. + let should_create = if auto_providers_override == Some(true) { + true + } else { + Confirm::new() + .with_prompt("Create from local credentials?") + .default(true) + .interact() + .into_diagnostic()? + }; + + if !should_create { + eprintln!("{} Skipping provider '{provider_type}'", "!".yellow()); + eprintln!(); + return Ok(()); + } + + let registry = ProviderRegistry::new(); + let discovered = registry + .discover_existing(provider_type) + .map_err(|err| miette::miette!("failed to discover provider '{provider_type}': {err}"))?; + let Some(discovered) = discovered else { + eprintln!( + "{} No existing local credentials/config found for '{}'. You can configure it from inside the sandbox.", + "!".yellow(), + provider_type + ); + eprintln!(); + return Ok(()); + }; + + if let Some(exact_name) = preferred_name { + // Explicit name: create with exactly that name, no retries. + let request = CreateProviderRequest { + provider: Some(Provider { + id: String::new(), + name: exact_name.to_string(), + r#type: provider_type.to_string(), + credentials: discovered.credentials.clone(), + config: discovered.config.clone(), + }), + }; + + let response = client.create_provider(request).await.map_err(|status| { + miette::miette!("failed to create provider '{exact_name}': {status}") + })?; + let provider = response + .into_inner() + .provider + .ok_or_else(|| miette::miette!("provider missing from response"))?; + eprintln!( + "{} Created provider {} ({}) from existing local state", + "✓".green().bold(), + provider.name, + provider.r#type + ); + if seen_names.insert(provider.name.clone()) { + configured_names.push(provider.name); + } + } else { + // Inferred type: try type as name, then suffixed variants. + let mut created = false; + for attempt in 0..5 { + let name = if attempt == 0 { + provider_type.to_string() + } else { + format!("{provider_type}-{attempt}") + }; + + let request = CreateProviderRequest { + provider: Some(Provider { + id: String::new(), + name: name.clone(), + r#type: provider_type.to_string(), + credentials: discovered.credentials.clone(), + config: discovered.config.clone(), + }), + }; + + match client.create_provider(request).await { + Ok(response) => { + let provider = response + .into_inner() + .provider + .ok_or_else(|| miette::miette!("provider missing from response"))?; eprintln!( - "{} No existing local credentials/config found for '{}'. You can configure it from inside the sandbox.", - "!".yellow(), - provider_type + "{} Created provider {} ({}) from existing local state", + "✓".green().bold(), + provider.name, + provider.r#type ); - eprintln!(); - continue; - }; - - let mut created = false; - for attempt in 0..5 { - let name = if attempt == 0 { - provider_type.clone() - } else { - format!("{provider_type}-{attempt}") - }; - - let request = CreateProviderRequest { - provider: Some(Provider { - id: String::new(), - name: name.clone(), - r#type: provider_type.clone(), - credentials: discovered.credentials.clone(), - config: discovered.config.clone(), - }), - }; - - match client.create_provider(request).await { - Ok(response) => { - let provider = response - .into_inner() - .provider - .ok_or_else(|| miette::miette!("provider missing from response"))?; - eprintln!( - "{} Created provider {} ({}) from existing local state", - "✓".green().bold(), - provider.name, - provider.r#type - ); - if seen_names.insert(provider.name.clone()) { - configured_names.push(provider.name); - } - created = true; - break; - } - Err(status) if status.code() == Code::AlreadyExists => {} - Err(status) => { - return Err(miette::miette!( - "failed to create provider for type '{provider_type}': {status}" - )); - } + if seen_names.insert(provider.name.clone()) { + configured_names.push(provider.name); } + created = true; + break; } - - if !created { + Err(status) if status.code() == Code::AlreadyExists => {} + Err(status) => { return Err(miette::miette!( - "failed to create provider for type '{provider_type}' after name retries" + "failed to create provider for type '{provider_type}': {status}" )); } - - eprintln!(); } } + + if !created { + return Err(miette::miette!( + "failed to create provider for type '{provider_type}' after name retries" + )); + } } - Ok(configured_names) + eprintln!(); + Ok(()) } fn parse_key_value_pairs(items: &[String], flag: &str) -> Result> { diff --git a/crates/navigator-cli/tests/ensure_providers_integration.rs b/crates/navigator-cli/tests/ensure_providers_integration.rs new file mode 100644 index 00000000..fbee593f --- /dev/null +++ b/crates/navigator-cli/tests/ensure_providers_integration.rs @@ -0,0 +1,652 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Integration tests for `ensure_required_providers` — verifies that explicit +//! `--provider` names are auto-created when they match a known provider type, +//! pass through when they already exist, and error for unrecognised names. + +use navigator_cli::run; +use navigator_cli::tls::TlsOptions; +use navigator_core::proto::navigator_server::{Navigator, NavigatorServer}; +use navigator_core::proto::{ + CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, + DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, + ExecSandboxEvent, ExecSandboxRequest, GetProviderRequest, GetSandboxPolicyRequest, + GetSandboxPolicyResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + Provider, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, +}; +use rcgen::{ + BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, +}; +use std::collections::HashMap; +use std::sync::Arc; +use tempfile::TempDir; +use tokio::net::TcpListener; +use tokio::sync::{Mutex, mpsc}; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::{Certificate as TlsCertificate, Identity, Server, ServerTlsConfig}; +use tonic::{Response, Status}; + +// ── EnvVarGuard ────────────────────────────────────────────────────── + +// Serialise tests that mutate environment variables so concurrent +// threads don't clobber each other. +static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + +struct SavedVar { + key: &'static str, + original: Option, +} + +/// Holds the global env lock and restores all modified variables on drop. +struct EnvVarGuard { + vars: Vec, + _lock: std::sync::MutexGuard<'static, ()>, +} + +#[allow(unsafe_code)] +impl EnvVarGuard { + /// Acquire the lock and set one or more environment variables. + fn set(pairs: &[(&'static str, &str)]) -> Self { + let lock = ENV_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut vars = Vec::with_capacity(pairs.len()); + for &(key, value) in pairs { + let original = std::env::var(key).ok(); + unsafe { + std::env::set_var(key, value); + } + vars.push(SavedVar { key, original }); + } + Self { vars, _lock: lock } + } +} + +#[allow(unsafe_code)] +impl Drop for EnvVarGuard { + fn drop(&mut self) { + for var in &self.vars { + if let Some(value) = &var.original { + unsafe { + std::env::set_var(var.key, value); + } + } else { + unsafe { + std::env::remove_var(var.key); + } + } + } + // _lock drops here, releasing the mutex + } +} + +// ── mock Navigator server ───────────────────────────────────────────── + +#[derive(Clone, Default)] +struct ProviderState { + providers: Arc>>, +} + +#[derive(Clone, Default)] +struct TestNavigator { + state: ProviderState, +} + +impl TestNavigator { + /// Seed the mock with an existing provider. + async fn seed_provider(&self, name: &str, provider_type: &str) { + let mut providers = self.state.providers.lock().await; + providers.insert( + name.to_string(), + Provider { + id: format!("id-{name}"), + name: name.to_string(), + r#type: provider_type.to_string(), + credentials: HashMap::new(), + config: HashMap::new(), + }, + ); + } +} + +#[tonic::async_trait] +impl Navigator for TestNavigator { + async fn health( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(HealthResponse { + status: ServiceStatus::Healthy.into(), + version: "test".to_string(), + })) + } + + async fn create_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(SandboxResponse::default())) + } + + async fn get_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(SandboxResponse::default())) + } + + async fn list_sandboxes( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(ListSandboxesResponse::default())) + } + + async fn delete_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(DeleteSandboxResponse { deleted: true })) + } + + async fn get_sandbox_policy( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(GetSandboxPolicyResponse::default())) + } + + async fn get_sandbox_provider_environment( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new( + GetSandboxProviderEnvironmentResponse::default(), + )) + } + + async fn create_ssh_session( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(CreateSshSessionResponse::default())) + } + + async fn revoke_ssh_session( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(RevokeSshSessionResponse::default())) + } + + async fn create_provider( + &self, + request: tonic::Request, + ) -> Result, Status> { + let mut provider = request + .into_inner() + .provider + .ok_or_else(|| Status::invalid_argument("provider is required"))?; + let mut providers = self.state.providers.lock().await; + if providers.contains_key(&provider.name) { + return Err(Status::already_exists("provider already exists")); + } + if provider.id.is_empty() { + provider.id = format!("id-{}", provider.name); + } + providers.insert(provider.name.clone(), provider.clone()); + Ok(Response::new(ProviderResponse { + provider: Some(provider), + })) + } + + async fn get_provider( + &self, + request: tonic::Request, + ) -> Result, Status> { + let name = request.into_inner().name; + let providers = self.state.providers.lock().await; + let provider = providers + .get(&name) + .cloned() + .ok_or_else(|| Status::not_found("provider not found"))?; + Ok(Response::new(ProviderResponse { + provider: Some(provider), + })) + } + + async fn list_providers( + &self, + _request: tonic::Request, + ) -> Result, Status> { + let providers = self + .state + .providers + .lock() + .await + .values() + .cloned() + .collect::>(); + Ok(Response::new(ListProvidersResponse { providers })) + } + + async fn update_provider( + &self, + request: tonic::Request, + ) -> Result, Status> { + let provider = request + .into_inner() + .provider + .ok_or_else(|| Status::invalid_argument("provider is required"))?; + + let mut providers = self.state.providers.lock().await; + let existing = providers + .get(&provider.name) + .cloned() + .ok_or_else(|| Status::not_found("provider not found"))?; + let updated = Provider { + id: existing.id, + name: provider.name, + r#type: provider.r#type, + credentials: provider.credentials, + config: provider.config, + }; + providers.insert(updated.name.clone(), updated.clone()); + Ok(Response::new(ProviderResponse { + provider: Some(updated), + })) + } + + async fn delete_provider( + &self, + request: tonic::Request, + ) -> Result, Status> { + let name = request.into_inner().name; + let deleted = self.state.providers.lock().await.remove(&name).is_some(); + Ok(Response::new(DeleteProviderResponse { deleted })) + } + + type WatchSandboxStream = + tokio_stream::wrappers::ReceiverStream>; + type ExecSandboxStream = + tokio_stream::wrappers::ReceiverStream>; + + async fn watch_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + let (_tx, rx) = mpsc::channel(1); + Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new( + rx, + ))) + } + + async fn exec_sandbox( + &self, + _request: tonic::Request, + ) -> Result, Status> { + let (_tx, rx) = mpsc::channel(1); + Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new( + rx, + ))) + } + + async fn update_sandbox_policy( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn get_sandbox_policy_status( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn list_sandbox_policies( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn report_policy_status( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn get_sandbox_logs( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } + + async fn push_sandbox_logs( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } +} + +// ── TLS helpers ────────────────────────────────────────────────────── + +fn install_rustls_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); +} + +fn build_ca() -> (Certificate, KeyPair) { + let key_pair = KeyPair::generate().unwrap(); + let mut params = CertificateParams::new(Vec::::new()).unwrap(); + params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + let cert = params.self_signed(&key_pair).unwrap(); + (cert, key_pair) +} + +fn build_server_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { + let key_pair = KeyPair::generate().unwrap(); + let mut params = CertificateParams::new(vec!["localhost".to_string()]).unwrap(); + params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; + let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); + (cert.pem(), key_pair.serialize_pem()) +} + +fn build_client_cert(ca: &Certificate, ca_key: &KeyPair) -> (String, String) { + let key_pair = KeyPair::generate().unwrap(); + let mut params = CertificateParams::new(Vec::::new()).unwrap(); + params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth]; + let cert = params.signed_by(&key_pair, ca, ca_key).unwrap(); + (cert.pem(), key_pair.serialize_pem()) +} + +// ── test server fixture ────────────────────────────────────────────── + +struct TestServer { + endpoint: String, + tls: TlsOptions, + navigator: TestNavigator, + _dir: TempDir, +} + +async fn run_server() -> TestServer { + install_rustls_provider(); + + let (ca, ca_key) = build_ca(); + let (server_cert, server_key) = build_server_cert(&ca, &ca_key); + let (client_cert, client_key) = build_client_cert(&ca, &ca_key); + let ca_cert = ca.pem(); + + let identity = Identity::from_pem(server_cert, server_key); + let client_ca = TlsCertificate::from_pem(ca_cert.clone()); + let tls_config = ServerTlsConfig::new() + .identity(identity) + .client_ca_root(client_ca); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let incoming = TcpListenerStream::new(listener); + + let navigator = TestNavigator::default(); + let svc_navigator = navigator.clone(); + + tokio::spawn(async move { + Server::builder() + .tls_config(tls_config) + .unwrap() + .add_service(NavigatorServer::new(svc_navigator)) + .serve_with_incoming(incoming) + .await + .unwrap(); + }); + + let dir = tempfile::tempdir().unwrap(); + let ca_path = dir.path().join("ca.crt"); + let cert_path = dir.path().join("tls.crt"); + let key_path = dir.path().join("tls.key"); + std::fs::write(&ca_path, ca_cert).unwrap(); + std::fs::write(&cert_path, client_cert).unwrap(); + std::fs::write(&key_path, client_key).unwrap(); + + let tls = TlsOptions::new(Some(ca_path), Some(cert_path), Some(key_path)); + let endpoint = format!("https://localhost:{}", addr.port()); + + TestServer { + endpoint, + tls, + navigator, + _dir: dir, + } +} + +// ── tests ──────────────────────────────────────────────────────────── + +/// When `--provider nvidia` is passed and a provider named "nvidia" already +/// exists, `ensure_required_providers` should return it directly without +/// creating anything new. +#[tokio::test] +async fn explicit_provider_name_passes_through_when_it_exists() { + let ts = run_server().await; + ts.navigator.seed_provider("nvidia", "nvidia").await; + + let mut client = navigator_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client"); + + let result = run::ensure_required_providers( + &mut client, + &["nvidia".to_string()], + &[], + Some(true), // --auto-providers (should not matter here) + ) + .await + .expect("should succeed"); + + assert_eq!(result, vec!["nvidia".to_string()]); + + // Verify no extra providers were created. + let providers = ts.navigator.state.providers.lock().await; + assert_eq!(providers.len(), 1, "no new providers should be created"); +} + +/// When `--provider nvidia` is passed, no provider named "nvidia" exists, and +/// "nvidia" is a valid provider type, the CLI should auto-create a provider +/// named "nvidia" of type "nvidia" using discovered local credentials. +#[tokio::test] +async fn explicit_provider_name_auto_creates_when_valid_type() { + let ts = run_server().await; + let _guard = EnvVarGuard::set(&[("NVIDIA_API_KEY", "nvapi-test-key")]); + + let mut client = navigator_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client"); + + let result = run::ensure_required_providers( + &mut client, + &["nvidia".to_string()], + &[], + Some(true), // --auto-providers to skip interactive prompt + ) + .await + .expect("should auto-create the provider"); + + assert_eq!(result, vec!["nvidia".to_string()]); + + // Verify the provider was created on the server with the right type. + let providers = ts.navigator.state.providers.lock().await; + let provider = providers + .get("nvidia") + .expect("nvidia provider should exist"); + assert_eq!(provider.r#type, "nvidia"); + assert_eq!( + provider.credentials.get("NVIDIA_API_KEY"), + Some(&"nvapi-test-key".to_string()), + ); +} + +/// When `--provider my-custom-thing` is passed and "my-custom-thing" is not a +/// known provider type, the CLI should return an error. +#[tokio::test] +async fn explicit_provider_name_errors_for_unrecognised_name() { + let ts = run_server().await; + + let mut client = navigator_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client"); + + let err = run::ensure_required_providers( + &mut client, + &["my-custom-thing".to_string()], + &[], + Some(true), + ) + .await + .expect_err("should fail for unrecognised provider name"); + + let msg = err.to_string(); + assert!( + msg.contains("my-custom-thing"), + "error should mention the name: {msg}" + ); + assert!( + msg.contains("not a recognized provider type"), + "error should explain why it failed: {msg}" + ); +} + +/// Inferred types (from the trailing command) that don't exist should be +/// auto-created, preserving the existing behaviour. +#[tokio::test] +async fn inferred_type_auto_creates_provider() { + let ts = run_server().await; + let _guard = EnvVarGuard::set(&[("ANTHROPIC_API_KEY", "sk-ant-test")]); + + let mut client = navigator_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client"); + + let result = run::ensure_required_providers( + &mut client, + &[], + &["claude".to_string()], + Some(true), // --auto-providers + ) + .await + .expect("should auto-create the inferred provider"); + + assert_eq!(result, vec!["claude".to_string()]); + + let providers = ts.navigator.state.providers.lock().await; + let provider = providers + .get("claude") + .expect("claude provider should exist"); + assert_eq!(provider.r#type, "claude"); +} + +/// When `--no-auto-providers` is set, missing explicit providers that would +/// otherwise be auto-created should be silently skipped. +#[tokio::test] +async fn no_auto_providers_skips_missing_explicit_provider() { + let ts = run_server().await; + let _guard = EnvVarGuard::set(&[("NVIDIA_API_KEY", "nvapi-skip-test")]); + + let mut client = navigator_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client"); + + let result = run::ensure_required_providers( + &mut client, + &["nvidia".to_string()], + &[], + Some(false), // --no-auto-providers + ) + .await + .expect("should succeed with empty list"); + + assert!( + result.is_empty(), + "skipped providers should not appear in the result" + ); + + let providers = ts.navigator.state.providers.lock().await; + assert!( + providers.is_empty(), + "no providers should be created when --no-auto-providers is set" + ); +} + +/// Both explicit names and inferred types should be resolved together, +/// deduplicating providers that appear in both lists. +#[tokio::test] +async fn explicit_and_inferred_providers_combined() { + let ts = run_server().await; + let _guard = EnvVarGuard::set(&[ + ("NVIDIA_API_KEY", "nvapi-combo"), + ("ANTHROPIC_API_KEY", "sk-ant-combo"), + ]); + + let mut client = navigator_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client"); + + let result = run::ensure_required_providers( + &mut client, + &["nvidia".to_string()], + &["claude".to_string()], + Some(true), + ) + .await + .expect("should create both providers"); + + assert_eq!(result.len(), 2); + assert!(result.contains(&"nvidia".to_string())); + assert!(result.contains(&"claude".to_string())); + + let providers = ts.navigator.state.providers.lock().await; + assert_eq!(providers.len(), 2); + assert!(providers.contains_key("nvidia")); + assert!(providers.contains_key("claude")); +} + +/// When an explicit provider name matches an inferred type, the provider +/// should only appear once in the result. +#[tokio::test] +async fn explicit_and_inferred_deduplicates() { + let ts = run_server().await; + let _guard = EnvVarGuard::set(&[("NVIDIA_API_KEY", "nvapi-dedup")]); + + let mut client = navigator_cli::tls::grpc_client(&ts.endpoint, &ts.tls) + .await + .expect("grpc client"); + + // Both explicit and inferred want "nvidia". + let result = run::ensure_required_providers( + &mut client, + &["nvidia".to_string()], + &["nvidia".to_string()], + Some(true), + ) + .await + .expect("should succeed"); + + assert_eq!( + result, + vec!["nvidia".to_string()], + "nvidia should appear exactly once" + ); + + let providers = ts.navigator.state.providers.lock().await; + assert_eq!( + providers.len(), + 1, + "only one provider should be created on the server" + ); +} diff --git a/e2e/rust/tests/provider_auto_create.rs b/e2e/rust/tests/provider_auto_create.rs new file mode 100644 index 00000000..c2fb7685 --- /dev/null +++ b/e2e/rust/tests/provider_auto_create.rs @@ -0,0 +1,106 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg(feature = "e2e")] + +//! E2E test: `--provider ` auto-creates a provider from local credentials. +//! +//! When `--provider claude` is passed and no provider named "claude" exists, +//! the CLI should discover `ANTHROPIC_API_KEY` from the local environment, +//! auto-create a provider, and inject its credentials into the sandbox. +//! +//! The sandbox command (`printenv ANTHROPIC_API_KEY`) verifies that the +//! credential made it all the way through to the sandbox process environment. +//! +//! Prerequisites: +//! - A running nemoclaw gateway (`nemoclaw gateway start`) +//! - The `nemoclaw` binary (built automatically from the workspace) + +use std::process::Stdio; + +use nemoclaw_e2e::harness::binary::nemoclaw_cmd; +use nemoclaw_e2e::harness::output::{extract_field, strip_ansi}; + +const TEST_API_KEY: &str = "sk-e2e-auto-provider-test-key"; + +/// Helper: delete a provider by name, ignoring errors. +async fn delete_provider(name: &str) { + let mut cmd = nemoclaw_cmd(); + cmd.arg("provider") + .arg("delete") + .arg(name) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + let _ = cmd.status().await; +} + +/// Helper: delete a sandbox by name, ignoring errors. +async fn delete_sandbox(name: &str) { + let mut cmd = nemoclaw_cmd(); + cmd.arg("sandbox") + .arg("delete") + .arg(name) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + let _ = cmd.status().await; +} + +/// `--provider claude --auto-providers` with `ANTHROPIC_API_KEY` set should +/// auto-create a "claude" provider and inject the credential into the sandbox. +#[tokio::test] +async fn auto_created_provider_credential_available_in_sandbox() { + // Clean up any leftover from a previous run. + delete_provider("claude").await; + + // Create a sandbox that prints the ANTHROPIC_API_KEY env var. + // --auto-providers skips the interactive prompt. + let mut cmd = nemoclaw_cmd(); + cmd.arg("sandbox") + .arg("create") + .arg("--provider") + .arg("claude") + .arg("--auto-providers") + .arg("--no-bootstrap") + .arg("--") + .arg("printenv") + .arg("ANTHROPIC_API_KEY") + .env("ANTHROPIC_API_KEY", TEST_API_KEY) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + let output = cmd + .output() + .await + .expect("failed to spawn nemoclaw sandbox create"); + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let combined = format!("{stdout}{stderr}"); + let clean = strip_ansi(&combined); + + // Parse sandbox name for cleanup. + let sandbox_name = extract_field(&combined, "Name"); + + // Always clean up, even if assertions fail. + if let Some(ref name) = sandbox_name { + delete_sandbox(name).await; + } + delete_provider("claude").await; + + // Now assert. + assert!( + output.status.success(), + "sandbox create should succeed (exit {:?}):\n{clean}", + output.status.code() + ); + + assert!( + clean.contains("Created provider claude"), + "output should confirm provider auto-creation:\n{clean}" + ); + + assert!( + clean.contains(TEST_API_KEY), + "sandbox should have ANTHROPIC_API_KEY in its environment:\n{clean}" + ); +}