1616# under the License.
1717from __future__ import annotations
1818
19- import json
2019import warnings
2120from typing import Any
2221
2322import yandexcloud
2423
25- from airflow .exceptions import AirflowException , AirflowProviderDeprecationWarning
24+ from airflow .exceptions import AirflowProviderDeprecationWarning
2625from airflow .hooks .base import BaseHook
26+ from airflow .providers .yandex .utils .credentials import (
27+ get_credentials ,
28+ get_service_account_id ,
29+ )
30+ from airflow .providers .yandex .utils .defaults import conn_name_attr , conn_type , default_conn_name , hook_name
31+ from airflow .providers .yandex .utils .fields import get_field_from_extras
32+ from airflow .providers .yandex .utils .user_agent import provider_user_agent
2733
2834
2935class YandexCloudBaseHook (BaseHook ):
3036 """
3137 A base hook for Yandex.Cloud related tasks.
3238
33- :param yandex_conn_id: The connection ID to use when fetching connection info.
39+ :param yandex_conn_id: The connection ID to use when fetching connection info
40+ :param connection_id: Deprecated, use yandex_conn_id instead
41+ :param default_folder_id: The folder ID to use instead of connection folder ID
42+ :param default_public_ssh_key: The key to use instead of connection key
43+ :param default_service_account_id: The service account ID to use instead of key service account ID
3444 """
3545
36- conn_name_attr = "yandex_conn_id"
37- default_conn_name = "yandexcloud_default"
38- conn_type = "yandexcloud"
39- hook_name = "Yandex Cloud"
46+ conn_name_attr = conn_name_attr
47+ default_conn_name = default_conn_name
48+ conn_type = conn_type
49+ hook_name = hook_name
4050
4151 @classmethod
4252 def get_connection_form_widgets (cls ) -> dict [str , Any ]:
@@ -50,14 +60,14 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
5060 lazy_gettext ("Service account auth JSON" ),
5161 widget = BS3PasswordFieldWidget (),
5262 description = "Service account auth JSON. Looks like "
53- '{"id", "...", "service_account_id": "...", "private_key": "..."}. '
63+ '{"id": "...", "service_account_id": "...", "private_key": "..."}. '
5464 "Will be used instead of OAuth token and SA JSON file path field if specified." ,
5565 ),
5666 "service_account_json_path" : StringField (
5767 lazy_gettext ("Service account auth JSON file path" ),
5868 widget = BS3TextFieldWidget (),
5969 description = "Service account auth JSON file path. File content looks like "
60- '{"id", "...", "service_account_id": "...", "private_key": "..."}. '
70+ '{"id": "...", "service_account_id": "...", "private_key": "..."}. '
6171 "Will be used instead of OAuth token if specified." ,
6272 ),
6373 "oauth" : PasswordField (
@@ -75,7 +85,7 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
7585 "public_ssh_key" : StringField (
7686 lazy_gettext ("Public SSH key" ),
7787 widget = BS3TextFieldWidget (),
78- description = "Optional. This key will be placed to all created Compute nodes"
88+ description = "Optional. This key will be placed to all created Compute nodes "
7989 "to let you have a root shell there" ,
8090 ),
8191 "endpoint" : StringField (
@@ -87,30 +97,13 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
8797
8898 @classmethod
8999 def provider_user_agent (cls ) -> str | None :
90- """Construct User-Agent from Airflow core & provider package versions."""
91- from airflow import __version__ as airflow_version
92- from airflow .configuration import conf
93- from airflow .providers_manager import ProvidersManager
94-
95- try :
96- manager = ProvidersManager ()
97- provider_name = manager .hooks [cls .conn_type ].package_name # type: ignore[union-attr]
98- provider = manager .providers [provider_name ]
99- return " " .join (
100- (
101- conf .get ("yandex" , "sdk_user_agent_prefix" , fallback = "" ),
102- f"apache-airflow/{ airflow_version } " ,
103- f"{ provider_name } /{ provider .version } " ,
104- )
105- ).strip ()
106- except KeyError :
107- warnings .warn (
108- f"Hook '{ cls .hook_name } ' info is not initialized in airflow.ProviderManager" ,
109- UserWarning ,
110- stacklevel = 2 ,
111- )
112-
113- return None
100+ warnings .warn (
101+ "Using `provider_user_agent` in `YandexCloudBaseHook` is deprecated. "
102+ "Please use it in `utils.user_agent` instead." ,
103+ AirflowProviderDeprecationWarning ,
104+ stacklevel = 2 ,
105+ )
106+ return provider_user_agent ()
114107
115108 @classmethod
116109 def get_ui_field_behaviour (cls ) -> dict [str , Any ]:
@@ -122,7 +115,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
122115
123116 def __init__ (
124117 self ,
125- # Connection id is deprecated. Use yandex_conn_id instead
118+ # connection_id is deprecated, use yandex_conn_id instead
126119 connection_id : str | None = None ,
127120 yandex_conn_id : str | None = None ,
128121 default_folder_id : str | None = None ,
@@ -137,46 +130,23 @@ def __init__(
137130 AirflowProviderDeprecationWarning ,
138131 stacklevel = 2 ,
139132 )
140- self .connection_id = yandex_conn_id or connection_id or self . default_conn_name
133+ self .connection_id = yandex_conn_id or connection_id or default_conn_name
141134 self .connection = self .get_connection (self .connection_id )
142135 self .extras = self .connection .extra_dejson
143- credentials = self ._get_credentials ()
136+ credentials = get_credentials (
137+ oauth_token = self ._get_field ("oauth" ),
138+ service_account_json = self ._get_field ("service_account_json" ),
139+ service_account_json_path = self ._get_field ("service_account_json_path" ),
140+ )
144141 sdk_config = self ._get_endpoint ()
145- self .sdk = yandexcloud .SDK (user_agent = self . provider_user_agent (), ** sdk_config , ** credentials )
142+ self .sdk = yandexcloud .SDK (user_agent = provider_user_agent (), ** sdk_config , ** credentials )
146143 self .default_folder_id = default_folder_id or self ._get_field ("folder_id" )
147144 self .default_public_ssh_key = default_public_ssh_key or self ._get_field ("public_ssh_key" )
148- self .default_service_account_id = default_service_account_id or self ._get_service_account_id ()
149- self .client = self .sdk .client
150-
151- def _get_service_account_key (self ) -> dict [str , str ] | None :
152- service_account_json = self ._get_field ("service_account_json" )
153- service_account_json_path = self ._get_field ("service_account_json_path" )
154- if service_account_json_path :
155- with open (service_account_json_path ) as infile :
156- service_account_json = infile .read ()
157- if service_account_json :
158- return json .loads (service_account_json )
159- return None
160-
161- def _get_service_account_id (self ) -> str | None :
162- sa_key = self ._get_service_account_key ()
163- if sa_key :
164- return sa_key .get ("service_account_id" )
165- return None
166-
167- def _get_credentials (self ) -> dict [str , Any ]:
168- oauth_token = self ._get_field ("oauth" )
169- if oauth_token :
170- return {"token" : oauth_token }
171-
172- service_account_key = self ._get_service_account_key ()
173- if service_account_key :
174- return {"service_account_key" : service_account_key }
175-
176- raise AirflowException (
177- "No credentials are found in connection. Specify either service account "
178- "authentication JSON or user OAuth token in Yandex.Cloud connection"
145+ self .default_service_account_id = default_service_account_id or get_service_account_id (
146+ service_account_json = self ._get_field ("service_account_json" ),
147+ service_account_json_path = self ._get_field ("service_account_json_path" ),
179148 )
149+ self .client = self .sdk .client
180150
181151 def _get_endpoint (self ) -> dict [str , str ]:
182152 sdk_config = {}
@@ -186,18 +156,6 @@ def _get_endpoint(self) -> dict[str, str]:
186156 return sdk_config
187157
188158 def _get_field (self , field_name : str , default : Any = None ) -> Any :
189- """Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
190159 if not hasattr (self , "extras" ):
191160 return default
192- backcompat_prefix = "extra__yandexcloud__"
193- if field_name .startswith ("extra__" ):
194- raise ValueError (
195- f"Got prefixed name { field_name } ; please remove the '{ backcompat_prefix } ' prefix "
196- "when using this method."
197- )
198- if field_name in self .extras :
199- return self .extras [field_name ]
200- prefixed_name = f"{ backcompat_prefix } { field_name } "
201- if prefixed_name in self .extras :
202- return self .extras [prefixed_name ]
203- return default
161+ return get_field_from_extras (self .extras , field_name , default )
0 commit comments