Skip to content

Commit ede34eb

Browse files
authored
Allow custom api versions in MSGraphAsyncOperator (#41331)
1 parent 1fb776f commit ede34eb

13 files changed

Lines changed: 613 additions & 466 deletions

File tree

airflow/providers/microsoft/azure/hooks/msgraph.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ class KiotaRequestAdapterHook(BaseHook):
9696
:param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
9797
When no timeout is specified or set to None then no HTTP timeout is applied on each request.
9898
:param proxies: A Dict defining the HTTP proxies to be used (default is None).
99+
:param host: The host to be used (default is "https://graph.microsoft.com").
100+
:param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]).
99101
:param api_version: The API version of the Microsoft Graph API to be used (default is v1).
100102
You can pass an enum named APIVersion which has 2 possible members v1 and beta,
101103
or you can pass a string as "v1.0" or "beta".
@@ -123,27 +125,22 @@ def __init__(
123125
self._api_version = self.resolve_api_version_from_value(api_version)
124126

125127
@property
126-
def api_version(self) -> APIVersion:
128+
def api_version(self) -> str | None:
127129
self.get_conn() # Make sure config has been loaded through get_conn to have correct api version!
128130
return self._api_version
129131

130132
@staticmethod
131133
def resolve_api_version_from_value(
132-
api_version: APIVersion | str, default: APIVersion | None = None
133-
) -> APIVersion:
134+
api_version: APIVersion | str, default: str | None = None
135+
) -> str | None:
134136
if isinstance(api_version, APIVersion):
135-
return api_version
136-
return next(
137-
filter(lambda version: version.value == api_version, APIVersion),
138-
default,
139-
)
137+
return api_version.value
138+
return api_version or default
140139

141-
def get_api_version(self, config: dict) -> APIVersion:
142-
if self._api_version is None:
143-
return self.resolve_api_version_from_value(
144-
api_version=config.get("api_version"), default=APIVersion.v1
145-
)
146-
return self._api_version
140+
def get_api_version(self, config: dict) -> str:
141+
return self._api_version or self.resolve_api_version_from_value(
142+
config.get("api_version"), APIVersion.v1.value
143+
) # type: ignore
147144

148145
def get_host(self, connection: Connection) -> str:
149146
if connection.schema and connection.host:
@@ -169,15 +166,15 @@ def to_httpx_proxies(cls, proxies: dict) -> dict:
169166
return proxies
170167

171168
def to_msal_proxies(self, authority: str | None, proxies: dict):
172-
self.log.info("authority: %s", authority)
169+
self.log.debug("authority: %s", authority)
173170
if authority:
174171
no_proxies = proxies.get("no")
175-
self.log.info("no_proxies: %s", no_proxies)
172+
self.log.debug("no_proxies: %s", no_proxies)
176173
if no_proxies:
177174
for url in no_proxies.split(","):
178175
self.log.info("url: %s", url)
179176
domain_name = urlparse(url).path.replace("*", "")
180-
self.log.info("domain_name: %s", domain_name)
177+
self.log.debug("domain_name: %s", domain_name)
181178
if authority.endswith(domain_name):
182179
return None
183180
return proxies
@@ -193,10 +190,10 @@ def get_conn(self) -> RequestAdapter:
193190
client_id = connection.login
194191
client_secret = connection.password
195192
config = connection.extra_dejson if connection.extra else {}
196-
tenant_id = config.get("tenant_id")
193+
tenant_id = config.get("tenant_id") or config.get("tenantId")
197194
api_version = self.get_api_version(config)
198195
host = self.get_host(connection)
199-
base_url = config.get("base_url", urljoin(host, api_version.value))
196+
base_url = config.get("base_url", urljoin(host, api_version))
200197
authority = config.get("authority")
201198
proxies = self.proxies or config.get("proxies", {})
202199
msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
@@ -209,15 +206,15 @@ def get_conn(self) -> RequestAdapter:
209206

210207
self.log.info(
211208
"Creating Microsoft Graph SDK client %s for conn_id: %s",
212-
api_version.value,
209+
api_version,
213210
self.conn_id,
214211
)
215212
self.log.info("Host: %s", host)
216213
self.log.info("Base URL: %s", base_url)
217214
self.log.info("Tenant id: %s", tenant_id)
218215
self.log.info("Client id: %s", client_id)
219216
self.log.info("Client secret: %s", client_secret)
220-
self.log.info("API version: %s", api_version.value)
217+
self.log.info("API version: %s", api_version)
221218
self.log.info("Scope: %s", scopes)
222219
self.log.info("Verify: %s", verify)
223220
self.log.info("Timeout: %s", self.timeout)
@@ -238,17 +235,17 @@ def get_conn(self) -> RequestAdapter:
238235
connection_verify=verify,
239236
)
240237
http_client = GraphClientFactory.create_with_default_middleware(
241-
api_version=api_version,
238+
api_version=api_version, # type: ignore
242239
client=httpx.AsyncClient(
243240
proxies=httpx_proxies,
244241
timeout=Timeout(timeout=self.timeout),
245242
verify=verify,
246243
trust_env=trust_env,
247244
),
248-
host=host,
245+
host=host, # type: ignore
249246
)
250247
auth_provider = AzureIdentityAuthenticationProvider(
251-
credentials=credentials,
248+
credentials=credentials, # type: ignore
252249
scopes=scopes,
253250
allowed_hosts=allowed_hosts,
254251
)
@@ -295,7 +292,7 @@ async def run(
295292
error_map=self.error_mapping(),
296293
)
297294

298-
self.log.info("response: %s", response)
295+
self.log.debug("response: %s", response)
299296

300297
return response
301298

airflow/providers/microsoft/azure/operators/msgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999
key: str = XCOM_RETURN_KEY,
100100
timeout: float | None = None,
101101
proxies: dict | None = None,
102-
api_version: APIVersion | None = None,
102+
api_version: APIVersion | str | None = None,
103103
pagination_function: Callable[[MSGraphAsyncOperator, dict], tuple[str, dict]] | None = None,
104104
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
105105
serializer: type[ResponseSerializer] = ResponseSerializer,

airflow/providers/microsoft/azure/operators/powerbi.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(
7676
conn_id: str = PowerBIHook.default_conn_name,
7777
timeout: float = 60 * 60 * 24 * 7,
7878
proxies: dict | None = None,
79-
api_version: APIVersion | None = None,
79+
api_version: APIVersion | str | None = None,
8080
check_interval: int = 60,
8181
**kwargs,
8282
) -> None:
@@ -89,6 +89,14 @@ def __init__(
8989
self.timeout = timeout
9090
self.check_interval = check_interval
9191

92+
@property
93+
def proxies(self) -> dict | None:
94+
return self.hook.proxies
95+
96+
@property
97+
def api_version(self) -> str | None:
98+
return self.hook.api_version
99+
92100
def execute(self, context: Context):
93101
"""Refresh the Power BI Dataset."""
94102
if self.wait_for_termination:
@@ -98,6 +106,8 @@ def execute(self, context: Context):
98106
group_id=self.group_id,
99107
dataset_id=self.dataset_id,
100108
timeout=self.timeout,
109+
proxies=self.proxies,
110+
api_version=self.api_version,
101111
check_interval=self.check_interval,
102112
wait_for_termination=self.wait_for_termination,
103113
),

airflow/providers/microsoft/azure/sensors/msgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
data: dict[str, Any] | str | BytesIO | None = None,
8383
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
8484
proxies: dict | None = None,
85-
api_version: APIVersion | None = None,
85+
api_version: APIVersion | str | None = None,
8686
event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded",
8787
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
8888
serializer: type[ResponseSerializer] = ResponseSerializer,

airflow/providers/microsoft/azure/triggers/msgraph.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(
122122
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
123123
timeout: float | None = None,
124124
proxies: dict | None = None,
125-
api_version: APIVersion | None = None,
125+
api_version: APIVersion | str | None = None,
126126
serializer: type[ResponseSerializer] = ResponseSerializer,
127127
):
128128
super().__init__()
@@ -152,14 +152,13 @@ def resolve_type(cls, value: str | type, default) -> type:
152152

153153
def serialize(self) -> tuple[str, dict[str, Any]]:
154154
"""Serialize the HttpTrigger arguments and classpath."""
155-
api_version = self.api_version.value if self.api_version else None
156155
return (
157156
f"{self.__class__.__module__}.{self.__class__.__name__}",
158157
{
159158
"conn_id": self.conn_id,
160159
"timeout": self.timeout,
161160
"proxies": self.proxies,
162-
"api_version": api_version,
161+
"api_version": self.api_version,
163162
"serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}",
164163
"url": self.url,
165164
"path_parameters": self.path_parameters,
@@ -188,7 +187,7 @@ def proxies(self) -> dict | None:
188187
return self.hook.proxies
189188

190189
@property
191-
def api_version(self) -> APIVersion:
190+
def api_version(self) -> APIVersion | str:
192191
return self.hook.api_version
193192

194193
async def run(self) -> AsyncIterator[TriggerEvent]:

airflow/providers/microsoft/azure/triggers/powerbi.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
group_id: str,
5959
timeout: float = 60 * 60 * 24 * 7,
6060
proxies: dict | None = None,
61-
api_version: APIVersion | None = None,
61+
api_version: APIVersion | str | None = None,
6262
check_interval: int = 60,
6363
wait_for_termination: bool = True,
6464
):
@@ -72,13 +72,12 @@ def __init__(
7272

7373
def serialize(self):
7474
"""Serialize the trigger instance."""
75-
api_version = self.api_version.value if self.api_version else None
7675
return (
7776
"airflow.providers.microsoft.azure.triggers.powerbi.PowerBITrigger",
7877
{
7978
"conn_id": self.conn_id,
8079
"proxies": self.proxies,
81-
"api_version": api_version,
80+
"api_version": self.api_version,
8281
"dataset_id": self.dataset_id,
8382
"group_id": self.group_id,
8483
"timeout": self.timeout,
@@ -96,7 +95,7 @@ def proxies(self) -> dict | None:
9695
return self.hook.proxies
9796

9897
@property
99-
def api_version(self) -> APIVersion:
98+
def api_version(self) -> APIVersion | str:
10099
return self.hook.api_version
101100

102101
async def run(self) -> AsyncIterator[TriggerEvent]:

docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ Below is an example of using this operator to refresh PowerBI dataset.
7272
:start-after: [START howto_operator_powerbi_refresh_dataset]
7373
:end-before: [END howto_operator_powerbi_refresh_dataset]
7474

75+
Below is an example of using this operator to create an item schedule in Fabric.
76+
77+
.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_msfabric.py
78+
:language: python
79+
:dedent: 0
80+
:start-after: [START howto_operator_ms_fabric_create_item_schedule]
81+
:end-before: [END howto_operator_ms_fabric_create_item_schedule]
82+
7583

7684
Reference
7785
---------
@@ -80,3 +88,4 @@ For further information, look at:
8088

8189
* `Use the Microsoft Graph API <https://learn.microsoft.com/en-us/graph/use-the-api/>`__
8290
* `Using the Power BI REST APIs <https://learn.microsoft.com/en-us/rest/api/power-bi/>`__
91+
* `Using the Fabric REST APIs <https://learn.microsoft.com/en-us/rest/api/fabric/articles/using-fabric-apis/>`__

0 commit comments

Comments
 (0)