Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/authentication/rh_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ def get_username(self) -> str:
return identity["user"]["username"]
return identity["account_number"]

def get_org_id(self) -> str:
"""Extract organization ID from identity data.

Returns:
Organization ID string, or empty string if not present
"""
return self.identity_data["identity"].get("org_id", "")

def has_entitlement(self, service: str) -> bool:
"""Check if user has a specific service entitlement.

Expand Down Expand Up @@ -239,6 +247,9 @@ async def __call__(self, request: Request) -> AuthTuple:
# Validate entitlements if configured
rh_identity.validate_entitlements()

# Store identity data in request.state for downstream access
request.state.rh_identity_data = rh_identity

# Extract user data
user_id = rh_identity.get_user_id()
username = rh_identity.get_username()
Expand Down
49 changes: 48 additions & 1 deletion tests/unit/authentication/test_rh_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def create_auth_header(identity_data: dict) -> str:
return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")


def create_request_with_header(header_value: Optional[str]) -> Request:
def create_request_with_header(header_value: Optional[str]) -> Mock:
"""Helper to create mock Request with x-rh-identity header.

Create a mock FastAPI Request with an `x-rh-identity` header for tests.
Expand All @@ -109,6 +109,7 @@ def create_request_with_header(header_value: Optional[str]) -> Request:
"""
request = Mock(spec=Request)
request.headers = {"x-rh-identity": header_value} if header_value else {}
request.state = Mock()
return request


Expand All @@ -129,6 +130,27 @@ def test_system_type_extraction(self, system_identity_data: dict) -> None:
assert rh_identity.get_user_id() == "c87dcb4c-8af1-40dd-878e-60c744edddd0"
assert rh_identity.get_username() == "123"

@pytest.mark.parametrize(
"fixture_name", ["user_identity_data", "system_identity_data"]
)
def test_get_org_id(
self, fixture_name: str, request: pytest.FixtureRequest
) -> None:
"""Test org_id extraction for both identity types."""
identity_data = request.getfixturevalue(fixture_name)
rh_identity = RHIdentityData(identity_data)
assert rh_identity.get_org_id() == "321"

def test_get_org_id_missing(self, user_identity_data: dict) -> None:
"""Test org_id returns empty string when not present."""
identity_data = {
**user_identity_data,
"identity": {**user_identity_data["identity"]},
}
identity_data["identity"].pop("org_id", None)
rh_identity = RHIdentityData(identity_data)
assert rh_identity.get_org_id() == ""

@pytest.mark.parametrize(
"service,expected",
[
Expand Down Expand Up @@ -309,6 +331,31 @@ async def test_system_authentication_success(
assert skip_check is False
assert token == NO_USER_TOKEN

@pytest.mark.asyncio
@pytest.mark.parametrize(
"fixture_name,expected_user_id",
[
("user_identity_data", "abc123"),
("system_identity_data", "c87dcb4c-8af1-40dd-878e-60c744edddd0"),
],
)
async def test_rh_identity_stored_in_request_state(
self, fixture_name: str, expected_user_id: str, request: pytest.FixtureRequest
) -> None:
"""Test RH Identity data is stored in request.state for downstream access."""
identity_data = request.getfixturevalue(fixture_name)
auth_dep = RHIdentityAuthDependency()
header_value = create_auth_header(identity_data)
mock_request = create_request_with_header(header_value)

await auth_dep(mock_request)

assert hasattr(mock_request.state, "rh_identity_data")
rh_identity = mock_request.state.rh_identity_data
assert isinstance(rh_identity, RHIdentityData)
assert rh_identity.get_user_id() == expected_user_id
assert rh_identity.get_org_id() == "321"

@pytest.mark.asyncio
async def test_missing_header(self) -> None:
"""Test authentication fails when header is missing."""
Expand Down
Loading