Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/e2e_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ jobs:
url: http://llama-stack:8321
api_key: xyzzy
user_data_collection:
feedback_disabled: false
feedback_enabled: true
feedback_storage: "/tmp/data/feedback"
transcripts_disabled: false
transcripts_enabled: true
transcripts_storage: "/tmp/data/transcripts"

authentication:
Expand Down
43 changes: 25 additions & 18 deletions src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jsonpath_ng.exceptions import JSONPathError
from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
FilePath,
Expand All @@ -21,7 +22,13 @@
from utils import checks


class TLSConfiguration(BaseModel):
class ConfigurationBase(BaseModel):
"""Base class for all configuration models that rejects unknown fields."""

model_config = ConfigDict(extra="forbid")


class TLSConfiguration(ConfigurationBase):
"""TLS configuration."""

tls_certificate_path: Optional[FilePath] = None
Expand All @@ -34,7 +41,7 @@ def check_tls_configuration(self) -> Self:
return self


class CORSConfiguration(BaseModel):
class CORSConfiguration(ConfigurationBase):
"""CORS configuration."""

allow_origins: list[str] = [
Expand All @@ -58,13 +65,13 @@ def check_cors_configuration(self) -> Self:
return self


class SQLiteDatabaseConfiguration(BaseModel):
class SQLiteDatabaseConfiguration(ConfigurationBase):
"""SQLite database configuration."""

db_path: str


class PostgreSQLDatabaseConfiguration(BaseModel):
class PostgreSQLDatabaseConfiguration(ConfigurationBase):
"""PostgreSQL database configuration."""

host: str = "localhost"
Expand All @@ -85,7 +92,7 @@ def check_postgres_configuration(self) -> Self:
return self


class DatabaseConfiguration(BaseModel):
class DatabaseConfiguration(ConfigurationBase):
"""Database configuration."""

sqlite: Optional[SQLiteDatabaseConfiguration] = None
Expand Down Expand Up @@ -126,7 +133,7 @@ def config(self) -> SQLiteDatabaseConfiguration | PostgreSQLDatabaseConfiguratio
raise ValueError("No database configuration found")


class ServiceConfiguration(BaseModel):
class ServiceConfiguration(ConfigurationBase):
"""Service configuration."""

host: str = "localhost"
Expand All @@ -146,15 +153,15 @@ def check_service_configuration(self) -> Self:
return self


class ModelContextProtocolServer(BaseModel):
class ModelContextProtocolServer(ConfigurationBase):
"""model context protocol server configuration."""

name: str
provider_id: str = "model-context-protocol"
url: str


class LlamaStackConfiguration(BaseModel):
class LlamaStackConfiguration(ConfigurationBase):
"""Llama stack configuration."""

url: Optional[str] = None
Expand Down Expand Up @@ -200,7 +207,7 @@ def check_llama_stack_model(self) -> Self:
return self


class UserDataCollection(BaseModel):
class UserDataCollection(ConfigurationBase):
"""User data collection configuration."""

feedback_enabled: bool = False
Expand Down Expand Up @@ -228,7 +235,7 @@ class JsonPathOperator(str, Enum):
IN = "in"


class JwtRoleRule(BaseModel):
class JwtRoleRule(ConfigurationBase):
"""Rule for extracting roles from JWT claims."""

jsonpath: str # JSONPath expression to evaluate against the JWT payload
Expand Down Expand Up @@ -306,22 +313,22 @@ class Action(str, Enum):
INFO = "info"


class AccessRule(BaseModel):
class AccessRule(ConfigurationBase):
"""Rule defining what actions a role can perform."""

role: str # Role name
actions: list[Action] # Allowed actions for this role


class AuthorizationConfiguration(BaseModel):
class AuthorizationConfiguration(ConfigurationBase):
"""Authorization configuration."""

access_rules: list[AccessRule] = Field(
default_factory=list
) # Rules for role-based access control


class JwtConfiguration(BaseModel):
class JwtConfiguration(ConfigurationBase):
"""JWT configuration."""

user_id_claim: str = constants.DEFAULT_JWT_UID_CLAIM
Expand All @@ -331,14 +338,14 @@ class JwtConfiguration(BaseModel):
) # Rules for extracting roles from JWT claims


class JwkConfiguration(BaseModel):
class JwkConfiguration(ConfigurationBase):
"""JWK configuration."""

url: AnyHttpUrl
jwt_configuration: JwtConfiguration = JwtConfiguration()

Comment on lines +341 to 346
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid nested default instances in JWK config.

Use default_factory to prevent shared state.

-    jwt_configuration: JwtConfiguration = JwtConfiguration()
+    jwt_configuration: JwtConfiguration = Field(default_factory=JwtConfiguration)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
class JwkConfiguration(ConfigurationBase):
"""JWK configuration."""
url: AnyHttpUrl
jwt_configuration: JwtConfiguration = JwtConfiguration()
class JwkConfiguration(ConfigurationBase):
"""JWK configuration."""
url: AnyHttpUrl
jwt_configuration: JwtConfiguration = Field(default_factory=JwtConfiguration)
🤖 Prompt for AI Agents
In src/models/config.py around lines 341 to 346, the JwtConfiguration default is
created as a shared instance (jwt_configuration: JwtConfiguration =
JwtConfiguration()), which can cause shared mutable state across model
instances; change this to use pydantic's default_factory (e.g.,
Field(default_factory=JwtConfiguration)) so each JwkConfiguration instance gets
its own JwtConfiguration object and avoid shared state.


class AuthenticationConfiguration(BaseModel):
class AuthenticationConfiguration(ConfigurationBase):
"""Authentication configuration."""

module: str = constants.DEFAULT_AUTHENTICATION_MODULE
Expand Down Expand Up @@ -377,7 +384,7 @@ def jwk_configuration(self) -> JwkConfiguration:
return self.jwk_config


class Customization(BaseModel):
class Customization(ConfigurationBase):
"""Service customization."""

disable_query_system_prompt: bool = False
Expand All @@ -395,7 +402,7 @@ def check_customization_model(self) -> Self:
return self


class InferenceConfiguration(BaseModel):
class InferenceConfiguration(ConfigurationBase):
"""Inference configuration."""

default_model: Optional[str] = None
Expand All @@ -415,7 +422,7 @@ def check_default_model_and_provider(self) -> Self:
return self


class Configuration(BaseModel):
class Configuration(ConfigurationBase):
"""Global service configuration."""

name: str
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_configuration_unknown_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Test configuration validation for unknown fields."""

import pytest
from pydantic import ValidationError

from models.config import ServiceConfiguration


def test_configuration_rejects_unknown_fields():
"""Test that configuration models reject unknown fields."""
with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
ServiceConfiguration(host="localhost", port=8080, unknown_field="should_fail")
Loading