Skip to content

Commit 001480c

Browse files
Test canvas related code in both eager and async mode
1 parent c1313b5 commit 001480c

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

tests/test_canvas.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,38 @@
33
from celery import chain, group, chord
44
from celery.worker import WorkController
55

6+
from _pytest.fixtures import SubRequest
7+
import pytest
8+
from pytest_django.fixtures import SettingsWrapper
9+
610
from demo.factories import AddToJobFactory, SumJobFactory, ValueJobFactory
711
from demo.models import AddToJob, SumJob, ValueJob
812

913
pytest_plugins = ("celery.contrib.pytest",)
1014

1115

12-
def test_chain(transactional_db: None, celery_worker: WorkController) -> None:
16+
@pytest.fixture(params=[True, False], ids=["eager", "async"])
17+
def execution_mode(request: SubRequest, settings: SettingsWrapper) -> None:
18+
if request.param:
19+
settings.CELERY_TASK_ALWAYS_EAGER = True
20+
settings.CELERY_TASK_STORE_EAGER_RESULT = True
21+
else:
22+
settings.CELERY_TASK_ALWAYS_EAGER = False
23+
settings.CELERY_TASK_STORE_EAGER_RESULT = False
24+
25+
26+
def test_chain(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None:
1327
value_job = cast(ValueJob, ValueJobFactory(value=5))
1428
add_to_job = cast(AddToJob, AddToJobFactory(value=15))
1529
assert chain(value_job.s(), add_to_job.s())().get() == value_job.value + add_to_job.value
1630

1731

18-
def test_group(transactional_db: None, celery_worker: WorkController) -> None:
32+
def test_group(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None:
1933
value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)]
2034
assert group([job.s() for job in value_jobs])().get() == [1, 2, 3]
2135

2236

23-
def test_chord(transactional_db: None, celery_worker: WorkController) -> None:
37+
def test_chord(execution_mode: None, transactional_db: None, celery_worker: WorkController) -> None:
2438
value_jobs = [cast(ValueJob, ValueJobFactory(value=i)) for i in range(1, 4)]
2539
sum_job = cast(SumJob, SumJobFactory())
2640
assert chord([job.s() for job in value_jobs])(sum_job.s()).get() == 6

0 commit comments

Comments
 (0)