Skip to content

Commit 12ccb5f

Browse files
feat: add Yandex Cloud Lockbox secrets backend (#36449)
* refactor: move credentials logic to utils * docs: using metadata service in Yandex.Cloud Connection * feat: Yandex Cloud Lockbox Secret Backend * docs: Yandex LockboxSecretBackend
1 parent e9a4bca commit 12ccb5f

21 files changed

Lines changed: 1816 additions & 223 deletions

File tree

airflow/providers/yandex/hooks/yandex.py

Lines changed: 40 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,37 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
import json
2019
import warnings
2120
from typing import Any
2221

2322
import yandexcloud
2423

25-
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
24+
from airflow.exceptions import AirflowProviderDeprecationWarning
2625
from airflow.hooks.base import BaseHook
26+
from airflow.providers.yandex.utils.credentials import (
27+
get_credentials,
28+
get_service_account_id,
29+
)
30+
from airflow.providers.yandex.utils.defaults import conn_name_attr, conn_type, default_conn_name, hook_name
31+
from airflow.providers.yandex.utils.fields import get_field_from_extras
32+
from airflow.providers.yandex.utils.user_agent import provider_user_agent
2733

2834

2935
class YandexCloudBaseHook(BaseHook):
3036
"""
3137
A base hook for Yandex.Cloud related tasks.
3238
33-
:param yandex_conn_id: The connection ID to use when fetching connection info.
39+
:param yandex_conn_id: The connection ID to use when fetching connection info
40+
:param connection_id: Deprecated, use yandex_conn_id instead
41+
:param default_folder_id: The folder ID to use instead of connection folder ID
42+
:param default_public_ssh_key: The key to use instead of connection key
43+
:param default_service_account_id: The service account ID to use instead of key service account ID
3444
"""
3545

36-
conn_name_attr = "yandex_conn_id"
37-
default_conn_name = "yandexcloud_default"
38-
conn_type = "yandexcloud"
39-
hook_name = "Yandex Cloud"
46+
conn_name_attr = conn_name_attr
47+
default_conn_name = default_conn_name
48+
conn_type = conn_type
49+
hook_name = hook_name
4050

4151
@classmethod
4252
def get_connection_form_widgets(cls) -> dict[str, Any]:
@@ -50,14 +60,14 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
5060
lazy_gettext("Service account auth JSON"),
5161
widget=BS3PasswordFieldWidget(),
5262
description="Service account auth JSON. Looks like "
53-
'{"id", "...", "service_account_id": "...", "private_key": "..."}. '
63+
'{"id": "...", "service_account_id": "...", "private_key": "..."}. '
5464
"Will be used instead of OAuth token and SA JSON file path field if specified.",
5565
),
5666
"service_account_json_path": StringField(
5767
lazy_gettext("Service account auth JSON file path"),
5868
widget=BS3TextFieldWidget(),
5969
description="Service account auth JSON file path. File content looks like "
60-
'{"id", "...", "service_account_id": "...", "private_key": "..."}. '
70+
'{"id": "...", "service_account_id": "...", "private_key": "..."}. '
6171
"Will be used instead of OAuth token if specified.",
6272
),
6373
"oauth": PasswordField(
@@ -75,7 +85,7 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
7585
"public_ssh_key": StringField(
7686
lazy_gettext("Public SSH key"),
7787
widget=BS3TextFieldWidget(),
78-
description="Optional. This key will be placed to all created Compute nodes"
88+
description="Optional. This key will be placed to all created Compute nodes "
7989
"to let you have a root shell there",
8090
),
8191
"endpoint": StringField(
@@ -87,30 +97,13 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
8797

8898
@classmethod
8999
def provider_user_agent(cls) -> str | None:
90-
"""Construct User-Agent from Airflow core & provider package versions."""
91-
from airflow import __version__ as airflow_version
92-
from airflow.configuration import conf
93-
from airflow.providers_manager import ProvidersManager
94-
95-
try:
96-
manager = ProvidersManager()
97-
provider_name = manager.hooks[cls.conn_type].package_name # type: ignore[union-attr]
98-
provider = manager.providers[provider_name]
99-
return " ".join(
100-
(
101-
conf.get("yandex", "sdk_user_agent_prefix", fallback=""),
102-
f"apache-airflow/{airflow_version}",
103-
f"{provider_name}/{provider.version}",
104-
)
105-
).strip()
106-
except KeyError:
107-
warnings.warn(
108-
f"Hook '{cls.hook_name}' info is not initialized in airflow.ProviderManager",
109-
UserWarning,
110-
stacklevel=2,
111-
)
112-
113-
return None
100+
warnings.warn(
101+
"Using `provider_user_agent` in `YandexCloudBaseHook` is deprecated. "
102+
"Please use it in `utils.user_agent` instead.",
103+
AirflowProviderDeprecationWarning,
104+
stacklevel=2,
105+
)
106+
return provider_user_agent()
114107

115108
@classmethod
116109
def get_ui_field_behaviour(cls) -> dict[str, Any]:
@@ -122,7 +115,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
122115

123116
def __init__(
124117
self,
125-
# Connection id is deprecated. Use yandex_conn_id instead
118+
# connection_id is deprecated, use yandex_conn_id instead
126119
connection_id: str | None = None,
127120
yandex_conn_id: str | None = None,
128121
default_folder_id: str | None = None,
@@ -137,46 +130,23 @@ def __init__(
137130
AirflowProviderDeprecationWarning,
138131
stacklevel=2,
139132
)
140-
self.connection_id = yandex_conn_id or connection_id or self.default_conn_name
133+
self.connection_id = yandex_conn_id or connection_id or default_conn_name
141134
self.connection = self.get_connection(self.connection_id)
142135
self.extras = self.connection.extra_dejson
143-
credentials = self._get_credentials()
136+
credentials = get_credentials(
137+
oauth_token=self._get_field("oauth"),
138+
service_account_json=self._get_field("service_account_json"),
139+
service_account_json_path=self._get_field("service_account_json_path"),
140+
)
144141
sdk_config = self._get_endpoint()
145-
self.sdk = yandexcloud.SDK(user_agent=self.provider_user_agent(), **sdk_config, **credentials)
142+
self.sdk = yandexcloud.SDK(user_agent=provider_user_agent(), **sdk_config, **credentials)
146143
self.default_folder_id = default_folder_id or self._get_field("folder_id")
147144
self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key")
148-
self.default_service_account_id = default_service_account_id or self._get_service_account_id()
149-
self.client = self.sdk.client
150-
151-
def _get_service_account_key(self) -> dict[str, str] | None:
152-
service_account_json = self._get_field("service_account_json")
153-
service_account_json_path = self._get_field("service_account_json_path")
154-
if service_account_json_path:
155-
with open(service_account_json_path) as infile:
156-
service_account_json = infile.read()
157-
if service_account_json:
158-
return json.loads(service_account_json)
159-
return None
160-
161-
def _get_service_account_id(self) -> str | None:
162-
sa_key = self._get_service_account_key()
163-
if sa_key:
164-
return sa_key.get("service_account_id")
165-
return None
166-
167-
def _get_credentials(self) -> dict[str, Any]:
168-
oauth_token = self._get_field("oauth")
169-
if oauth_token:
170-
return {"token": oauth_token}
171-
172-
service_account_key = self._get_service_account_key()
173-
if service_account_key:
174-
return {"service_account_key": service_account_key}
175-
176-
raise AirflowException(
177-
"No credentials are found in connection. Specify either service account "
178-
"authentication JSON or user OAuth token in Yandex.Cloud connection"
145+
self.default_service_account_id = default_service_account_id or get_service_account_id(
146+
service_account_json=self._get_field("service_account_json"),
147+
service_account_json_path=self._get_field("service_account_json_path"),
179148
)
149+
self.client = self.sdk.client
180150

181151
def _get_endpoint(self) -> dict[str, str]:
182152
sdk_config = {}
@@ -186,18 +156,6 @@ def _get_endpoint(self) -> dict[str, str]:
186156
return sdk_config
187157

188158
def _get_field(self, field_name: str, default: Any = None) -> Any:
189-
"""Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
190159
if not hasattr(self, "extras"):
191160
return default
192-
backcompat_prefix = "extra__yandexcloud__"
193-
if field_name.startswith("extra__"):
194-
raise ValueError(
195-
f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix "
196-
"when using this method."
197-
)
198-
if field_name in self.extras:
199-
return self.extras[field_name]
200-
prefixed_name = f"{backcompat_prefix}{field_name}"
201-
if prefixed_name in self.extras:
202-
return self.extras[prefixed_name]
203-
return default
161+
return get_field_from_extras(self.extras, field_name, default)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.

0 commit comments

Comments
 (0)