|
33 | 33 | from airflow.models.baseoperator import BaseOperator |
34 | 34 | from airflow.models.taskinstance import TaskInstance, TaskInstanceKey |
35 | 35 | from airflow.utils import timezone |
36 | | -from airflow.utils.state import State |
| 36 | +from airflow.utils.state import State, TaskInstanceState |
37 | 37 |
|
38 | 38 |
|
39 | 39 | def test_supports_sentry(): |
@@ -363,3 +363,54 @@ def test_running_retry_attempt_type(loop_duration, total_tries): |
363 | 363 | assert a.elapsed > min_seconds_for_test |
364 | 364 | assert a.total_tries == total_tries |
365 | 365 | assert a.tries_after_min == 1 |
| 366 | + |
| 367 | + |
| 368 | +def test_state_fail(): |
| 369 | + executor = BaseExecutor() |
| 370 | + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) |
| 371 | + executor.running.add(key) |
| 372 | + info = "info" |
| 373 | + executor.fail(key, info=info) |
| 374 | + assert not executor.running |
| 375 | + assert executor.event_buffer[key] == (TaskInstanceState.FAILED, info) |
| 376 | + |
| 377 | + |
| 378 | +def test_state_success(): |
| 379 | + executor = BaseExecutor() |
| 380 | + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) |
| 381 | + executor.running.add(key) |
| 382 | + info = "info" |
| 383 | + executor.success(key, info=info) |
| 384 | + assert not executor.running |
| 385 | + assert executor.event_buffer[key] == (TaskInstanceState.SUCCESS, info) |
| 386 | + |
| 387 | + |
| 388 | +def test_state_queued(): |
| 389 | + executor = BaseExecutor() |
| 390 | + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) |
| 391 | + executor.running.add(key) |
| 392 | + info = "info" |
| 393 | + executor.queued(key, info=info) |
| 394 | + assert not executor.running |
| 395 | + assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info) |
| 396 | + |
| 397 | + |
| 398 | +def test_state_generic(): |
| 399 | + executor = BaseExecutor() |
| 400 | + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) |
| 401 | + executor.running.add(key) |
| 402 | + info = "info" |
| 403 | + executor.queued(key, info=info) |
| 404 | + assert not executor.running |
| 405 | + assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info) |
| 406 | + |
| 407 | + |
| 408 | +def test_state_running(): |
| 409 | + executor = BaseExecutor() |
| 410 | + key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1) |
| 411 | + executor.running.add(key) |
| 412 | + info = "info" |
| 413 | + executor.running_state(key, info=info) |
| 414 | + # Running state should not remove a command as running |
| 415 | + assert executor.running |
| 416 | + assert executor.event_buffer[key] == (TaskInstanceState.RUNNING, info) |
0 commit comments