Skip to content

Commit 36d57e4

Browse files
committed
refactor token creation logic to use methods in data models
1 parent 9797f17 commit 36d57e4

File tree

3 files changed

+29
-24
lines changed
  • providers/keycloak

3 files changed

+29
-24
lines changed

providers/keycloak/src/airflow/providers/keycloak/auth_manager/datamodels/token.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
from pydantic import Field, RootModel, model_validator
2323

2424
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
25+
from airflow.providers.keycloak.auth_manager.services.token import (
26+
create_client_credentials_token,
27+
create_token_for,
28+
)
2529

2630

2731
class TokenResponse(BaseModel):
@@ -37,6 +41,12 @@ class TokenPasswordBody(StrictBaseModel):
3741
username: str = Field()
3842
password: str = Field()
3943

44+
def create_token(self, expiration_time_in_seconds: int) -> str:
45+
"""Create token using password grant."""
46+
return create_token_for(
47+
self.username, self.password, expiration_time_in_seconds=expiration_time_in_seconds
48+
)
49+
4050

4151
class TokenClientCredentialsBody(StrictBaseModel):
4252
"""Client Credentials Grant Token serializer for post bodies."""
@@ -45,6 +55,12 @@ class TokenClientCredentialsBody(StrictBaseModel):
4555
client_id: str = Field()
4656
client_secret: str = Field()
4757

58+
def create_token(self, expiration_time_in_seconds: int) -> str:
59+
"""Create token using client credentials grant."""
60+
return create_client_credentials_token(
61+
self.client_id, self.client_secret, expiration_time_in_seconds=expiration_time_in_seconds
62+
)
63+
4864

4965
TokenUnion = Annotated[
5066
TokenPasswordBody | TokenClientCredentialsBody,

providers/keycloak/src/airflow/providers/keycloak/auth_manager/routes/token.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,9 @@
2626
from airflow.configuration import conf
2727
from airflow.providers.keycloak.auth_manager.datamodels.token import (
2828
TokenBody,
29-
TokenClientCredentialsBody,
3029
TokenPasswordBody,
3130
TokenResponse,
3231
)
33-
from airflow.providers.keycloak.auth_manager.services.token import (
34-
create_client_credentials_token,
35-
create_token_for,
36-
)
3732

3833
log = logging.getLogger(__name__)
3934
token_router = AirflowRouter(tags=["KeycloakAuthManagerToken"])
@@ -44,16 +39,10 @@
4439
status_code=status.HTTP_201_CREATED,
4540
responses=create_openapi_http_exception_doc([status.HTTP_400_BAD_REQUEST, status.HTTP_401_UNAUTHORIZED]),
4641
)
47-
def create_token(
48-
body: TokenBody,
49-
) -> TokenResponse:
50-
credentials = body.root
51-
if isinstance(credentials, TokenPasswordBody):
52-
token = create_token_for(credentials.username, credentials.password)
53-
elif isinstance(credentials, TokenClientCredentialsBody):
54-
token = create_client_credentials_token(credentials.client_id, credentials.client_secret)
55-
else:
56-
raise ValueError("Unsupported grant_type")
42+
def create_token(body: TokenBody) -> TokenResponse:
43+
token = body.root.create_token(
44+
expiration_time_in_seconds=int(conf.getint("api_auth", "jwt_expiration_time"))
45+
)
5746
return TokenResponse(access_token=token)
5847

5948

@@ -63,9 +52,7 @@ def create_token(
6352
responses=create_openapi_http_exception_doc([status.HTTP_400_BAD_REQUEST, status.HTTP_401_UNAUTHORIZED]),
6453
)
6554
def create_token_cli(body: TokenPasswordBody) -> TokenResponse:
66-
token = create_token_for(
67-
body.username,
68-
body.password,
69-
expiration_time_in_seconds=int(conf.getint("api_auth", "jwt_cli_expiration_time")),
55+
token = body.create_token(
56+
expiration_time_in_seconds=int(conf.getint("api_auth", "jwt_cli_expiration_time"))
7057
)
7158
return TokenResponse(access_token=token)

providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_token.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class TestTokenRouter:
4141
("api_auth", "jwt_expiration_time"): "10",
4242
}
4343
)
44-
@patch("airflow.providers.keycloak.auth_manager.routes.token.create_token_for")
44+
@patch("airflow.providers.keycloak.auth_manager.datamodels.token.create_token_for")
4545
def test_create_token_password_grant(self, mock_create_token_for, client, body):
4646
mock_create_token_for.return_value = self.token
4747
response = client.post(
@@ -58,7 +58,7 @@ def test_create_token_password_grant(self, mock_create_token_for, client, body):
5858
("api_auth", "jwt_expiration_time"): "10",
5959
}
6060
)
61-
@patch("airflow.providers.keycloak.auth_manager.routes.token.create_token_for")
61+
@patch("airflow.providers.keycloak.auth_manager.datamodels.token.create_token_for")
6262
def test_create_token_cli(self, mock_create_token_for, client):
6363
mock_create_token_for.return_value = self.token
6464
response = client.post(
@@ -74,7 +74,7 @@ def test_create_token_cli(self, mock_create_token_for, client):
7474
("api_auth", "jwt_expiration_time"): "10",
7575
}
7676
)
77-
@patch("airflow.providers.keycloak.auth_manager.routes.token.create_client_credentials_token")
77+
@patch("airflow.providers.keycloak.auth_manager.datamodels.token.create_client_credentials_token")
7878
def test_create_token_client_credentials(self, mock_create_client_credentials_token, client):
7979
mock_create_client_credentials_token.return_value = self.token
8080
response = client.post(
@@ -88,7 +88,9 @@ def test_create_token_client_credentials(self, mock_create_client_credentials_to
8888

8989
assert response.status_code == 201
9090
assert response.json() == {"access_token": self.token}
91-
mock_create_client_credentials_token.assert_called_once_with("client_id", "client_secret")
91+
mock_create_client_credentials_token.assert_called_once_with(
92+
"client_id", "client_secret", expiration_time_in_seconds=10
93+
)
9294

9395
@pytest.mark.parametrize(
9496
"body",
@@ -107,7 +109,7 @@ def test_create_token_client_credentials(self, mock_create_client_credentials_to
107109
("api_auth", "jwt_expiration_time"): "10",
108110
}
109111
)
110-
@patch("airflow.providers.keycloak.auth_manager.routes.token.create_client_credentials_token")
112+
@patch("airflow.providers.keycloak.auth_manager.datamodels.token.create_client_credentials_token")
111113
def test_create_token_invalid_body(self, mock_create_client_credentials_token, client, body):
112114
mock_create_client_credentials_token.return_value = self.token
113115
response = client.post(

0 commit comments

Comments
 (0)