|
3 | 3 | from celery import chain, group, chord |
4 | 4 | from celery.worker import WorkController |
5 | 5 |
|
| 6 | +from _pytest.fixtures import SubRequest |
| 7 | +import pytest |
| 8 | +from pytest_django.fixtures import SettingsWrapper |
| 9 | + |
6 | 10 | from demo.factories import AddToJobFactory, SumJobFactory, ValueJobFactory |
7 | 11 | from demo.models import AddToJob, SumJob, ValueJob |
8 | 12 |
|
9 | 13 | pytest_plugins = ("celery.contrib.pytest",) |
10 | 14 |
|
11 | 15 |
|
12 | | -def test_chain(transactional_db: None, celery_worker: WorkController) -> None: |
| 16 | +@pytest.fixture(params=[True, False], ids=["eager", "async"]) |
| 17 | +def execution_mode(request: SubRequest, settings: SettingsWrapper) -> None: |
| 18 | + if request.param: |
| 19 | + settings.CELERY_TASK_ALWAYS_EAGER = True |
| 20 | + settings.CELERY_TASK_STORE_EAGER_RESULT = True |
| 21 | + else: |
| 22 | + settings.CELERY_TASK_ALWAYS_EAGER = False |
| 23 | + settings.CELERY_TASK_STORE_EAGER_RESULT = False |
| 24 | + |
| 25 | + |
| 26 | +def test_chain(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: |
13 | 27 | value_job = cast(ValueJob, ValueJobFactory(value=5)) |
14 | 28 | add_to_job = cast(AddToJob, AddToJobFactory(value=15)) |
15 | 29 | assert chain(value_job.s(), add_to_job.s())().get() == value_job.value + add_to_job.value |
16 | 30 |
|
17 | 31 |
|
18 | | -def test_group(transactional_db: None, celery_worker: WorkController) -> None: |
| 32 | +def test_group(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: |
19 | 33 | value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] |
20 | 34 | assert group([job.s() for job in value_jobs])().get() == [1, 2, 3] |
21 | 35 |
|
22 | 36 |
|
23 | | -def test_chord(transactional_db: None, celery_worker: WorkController) -> None: |
| 37 | +def test_chord(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None: |
24 | 38 | value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)] |
25 | 39 | sum_job = cast(SumJob, SumJobFactory()) |
26 | 40 | assert chord([job.s() for job in value_jobs])(sum_job.s()).get() == 6 |
0 commit comments