diff --git a/src/models/config.py b/src/models/config.py index 8acf95799..bda9699ad 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -30,13 +30,21 @@ class CORSConfiguration(BaseModel): allow_origins: list[str] = [ "*" ] # not AnyHttpUrl: we need to support "*" that is not valid URL - allow_credentials: bool = True + allow_credentials: bool = False allow_methods: list[str] = ["*"] allow_headers: list[str] = ["*"] @model_validator(mode="after") def check_cors_configuration(self) -> Self: """Check CORS configuration.""" + # credentials are not allowed with wildcard origins per CORS/Fetch spec. + # see https://fastapi.tiangolo.com/tutorial/cors/ + if self.allow_credentials and "*" in self.allow_origins: + raise ValueError( + "Invalid CORS configuration: allow_credentials can not be set to true when " + "allow origins contains '*' wildcard." + "Use explicit origins or disable credential." + ) return self diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index f6302ef96..02a976c68 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -220,12 +220,12 @@ def test_cors_default_configuration() -> None: cfg = CORSConfiguration() assert cfg is not None assert cfg.allow_origins == ["*"] - assert cfg.allow_credentials is True + assert cfg.allow_credentials is False assert cfg.allow_methods == ["*"] assert cfg.allow_headers == ["*"] -def test_cors_custom_configuration() -> None: +def test_cors_custom_configuration_v1() -> None: """Test the CORS configuration.""" cfg = CORSConfiguration( allow_origins=["foo_origin", "bar_origin", "baz_origin"], @@ -240,6 +240,54 @@ def test_cors_custom_configuration() -> None: assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"] +def test_cors_custom_configuration_v2() -> None: + """Test the CORS configuration.""" + cfg = CORSConfiguration( + allow_origins=["foo_origin", "bar_origin", "baz_origin"], + allow_credentials=True, + allow_methods=["foo_method", "bar_method", "baz_method"], + allow_headers=["foo_header", "bar_header", "baz_header"], + ) + assert cfg is not None + assert cfg.allow_origins == ["foo_origin", "bar_origin", "baz_origin"] + assert cfg.allow_credentials is True + assert cfg.allow_methods == ["foo_method", "bar_method", "baz_method"] + assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"] + + +def test_cors_custom_configuration_v3() -> None: + """Test the CORS configuration.""" + cfg = CORSConfiguration( + allow_origins=["*"], + allow_credentials=False, + allow_methods=["foo_method", "bar_method", "baz_method"], + allow_headers=["foo_header", "bar_header", "baz_header"], + ) + assert cfg is not None + assert cfg.allow_origins == ["*"] + assert cfg.allow_credentials is False + assert cfg.allow_methods == ["foo_method", "bar_method", "baz_method"] + assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"] + + +def test_cors_improper_configuration() -> None: + """Test the CORS configuration.""" + expected = ( + "Value error, Invalid CORS configuration: " + + "allow_credentials can not be set to true when allow origins contains '\\*' wildcard." + + "Use explicit origins or disable credential." + ) + + with pytest.raises(ValueError, match=expected): + # allow_credentials can not be true when allow_origins contains '*' + CORSConfiguration( + allow_origins=["*"], + allow_credentials=True, + allow_methods=["foo_method", "bar_method", "baz_method"], + allow_headers=["foo_header", "bar_header", "baz_header"], + ) + + def test_tls_configuration() -> None: """Test the TLS configuration.""" cfg = TLSConfiguration(