diff --git a/.github/workflows/e2e_tests.yaml b/.github/workflows/e2e_tests.yaml index 5a339598c..4aeabe656 100644 --- a/.github/workflows/e2e_tests.yaml +++ b/.github/workflows/e2e_tests.yaml @@ -83,9 +83,9 @@ jobs: url: http://llama-stack:8321 api_key: xyzzy user_data_collection: - feedback_disabled: false + feedback_enabled: true feedback_storage: "/tmp/data/feedback" - transcripts_disabled: false + transcripts_enabled: true transcripts_storage: "/tmp/data/transcripts" authentication: diff --git a/src/models/config.py b/src/models/config.py index de6cc3e2d..9166eb4bc 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -8,6 +8,7 @@ from jsonpath_ng.exceptions import JSONPathError from pydantic import ( BaseModel, + ConfigDict, Field, model_validator, FilePath, @@ -21,7 +22,13 @@ from utils import checks -class TLSConfiguration(BaseModel): +class ConfigurationBase(BaseModel): + """Base class for all configuration models that rejects unknown fields.""" + + model_config = ConfigDict(extra="forbid") + + +class TLSConfiguration(ConfigurationBase): """TLS configuration.""" tls_certificate_path: Optional[FilePath] = None @@ -34,7 +41,7 @@ def check_tls_configuration(self) -> Self: return self -class CORSConfiguration(BaseModel): +class CORSConfiguration(ConfigurationBase): """CORS configuration.""" allow_origins: list[str] = [ @@ -58,13 +65,13 @@ def check_cors_configuration(self) -> Self: return self -class SQLiteDatabaseConfiguration(BaseModel): +class SQLiteDatabaseConfiguration(ConfigurationBase): """SQLite database configuration.""" db_path: str -class PostgreSQLDatabaseConfiguration(BaseModel): +class PostgreSQLDatabaseConfiguration(ConfigurationBase): """PostgreSQL database configuration.""" host: str = "localhost" @@ -85,7 +92,7 @@ def check_postgres_configuration(self) -> Self: return self -class DatabaseConfiguration(BaseModel): +class DatabaseConfiguration(ConfigurationBase): """Database configuration.""" sqlite: Optional[SQLiteDatabaseConfiguration] = None @@ -126,7 +133,7 @@ def config(self) -> SQLiteDatabaseConfiguration | PostgreSQLDatabaseConfiguratio raise ValueError("No database configuration found") -class ServiceConfiguration(BaseModel): +class ServiceConfiguration(ConfigurationBase): """Service configuration.""" host: str = "localhost" @@ -146,7 +153,7 @@ def check_service_configuration(self) -> Self: return self -class ModelContextProtocolServer(BaseModel): +class ModelContextProtocolServer(ConfigurationBase): """model context protocol server configuration.""" name: str @@ -154,7 +161,7 @@ class ModelContextProtocolServer(BaseModel): url: str -class LlamaStackConfiguration(BaseModel): +class LlamaStackConfiguration(ConfigurationBase): """Llama stack configuration.""" url: Optional[str] = None @@ -200,7 +207,7 @@ def check_llama_stack_model(self) -> Self: return self -class UserDataCollection(BaseModel): +class UserDataCollection(ConfigurationBase): """User data collection configuration.""" feedback_enabled: bool = False @@ -228,7 +235,7 @@ class JsonPathOperator(str, Enum): IN = "in" -class JwtRoleRule(BaseModel): +class JwtRoleRule(ConfigurationBase): """Rule for extracting roles from JWT claims.""" jsonpath: str # JSONPath expression to evaluate against the JWT payload @@ -306,14 +313,14 @@ class Action(str, Enum): INFO = "info" -class AccessRule(BaseModel): +class AccessRule(ConfigurationBase): """Rule defining what actions a role can perform.""" role: str # Role name actions: list[Action] # Allowed actions for this role -class AuthorizationConfiguration(BaseModel): +class AuthorizationConfiguration(ConfigurationBase): """Authorization configuration.""" access_rules: list[AccessRule] = Field( @@ -321,7 +328,7 @@ class AuthorizationConfiguration(BaseModel): ) # Rules for role-based access control -class JwtConfiguration(BaseModel): +class JwtConfiguration(ConfigurationBase): """JWT configuration.""" user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM @@ -331,14 +338,14 @@ class JwtConfiguration(BaseModel): ) # Rules for extracting roles from JWT claims -class JwkConfiguration(BaseModel): +class JwkConfiguration(ConfigurationBase): """JWK configuration.""" url: AnyHttpUrl jwt_configuration: JwtConfiguration = JwtConfiguration() -class AuthenticationConfiguration(BaseModel): +class AuthenticationConfiguration(ConfigurationBase): """Authentication configuration.""" module: str = constants.DEFAULT_AUTHENTICATION_MODULE @@ -377,7 +384,7 @@ def jwk_configuration(self) -> JwkConfiguration: return self.jwk_config -class Customization(BaseModel): +class Customization(ConfigurationBase): """Service customization.""" disable_query_system_prompt: bool = False @@ -395,7 +402,7 @@ def check_customization_model(self) -> Self: return self -class InferenceConfiguration(BaseModel): +class InferenceConfiguration(ConfigurationBase): """Inference configuration.""" default_model: Optional[str] = None @@ -415,7 +422,7 @@ def check_default_model_and_provider(self) -> Self: return self -class Configuration(BaseModel): +class Configuration(ConfigurationBase): """Global service configuration.""" name: str diff --git a/tests/unit/test_configuration_unknown_fields.py b/tests/unit/test_configuration_unknown_fields.py new file mode 100644 index 000000000..b97c3f0d3 --- /dev/null +++ b/tests/unit/test_configuration_unknown_fields.py @@ -0,0 +1,12 @@ +"""Test configuration validation for unknown fields.""" + +import pytest +from pydantic import ValidationError + +from models.config import ServiceConfiguration + + +def test_configuration_rejects_unknown_fields(): + """Test that configuration models reject unknown fields.""" + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + ServiceConfiguration(host="localhost", port=8080, unknown_field="should_fail")