From 7e2531b179505aba4e92dac724953d4854ddc46f Mon Sep 17 00:00:00 2001 From: ivan Date: Mon, 15 Jul 2024 16:01:27 +0300 Subject: [PATCH 1/8] Implement overriding with context manager for one provider --- that_depends/providers/base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 1b99970e..c331910b 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -1,5 +1,6 @@ import abc import typing +from contextlib import contextmanager T = typing.TypeVar("T") @@ -24,6 +25,12 @@ async def __call__(self) -> T_co: def override(self, mock: object) -> None: self._override = mock + @contextmanager + def override_context(self, mock: object) -> typing.Iterator[None]: + self.override(mock) + yield + self.reset_override() + def reset_override(self) -> None: self._override = None From 5a930a88e9a1e2b7247294ee44c4f284b3dbf3ed Mon Sep 17 00:00:00 2001 From: ivan Date: Mon, 15 Jul 2024 16:01:51 +0300 Subject: [PATCH 2/8] Implement batch overriding with context manager for container --- that_depends/container.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/that_depends/container.py b/that_depends/container.py index f4c874c4..35a9977f 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -1,5 +1,6 @@ import inspect import typing +from contextlib import contextmanager from that_depends.providers import AbstractProvider, AbstractResource, Singleton @@ -92,3 +93,25 @@ async def resolve(cls, object_to_resolve: type[T] | typing.Callable[..., T]) -> kwargs[field_name] = await providers[field_name].async_resolve() return object_to_resolve(**kwargs) + + @classmethod + @contextmanager + def override_providers(cls, providers_for_overriding: dict[str, typing.Any]) -> typing.Iterator[None]: + current_providers = cls.get_providers() + current_provider_names = set(current_providers.keys()) + given_provider_names = set(providers_for_overriding.keys()) + + for given_name in given_provider_names: + if given_name not in current_provider_names: + msg = f"Provider with name {given_name!r} not found" + raise RuntimeError(msg) + + for provider_name, mock in providers_for_overriding.items(): + provider = current_providers[provider_name] + provider.override(mock) + + yield + + for provider_name in providers_for_overriding: + provider = current_providers[provider_name] + provider.reset_override() From c5dfe1ac3e8b21faf2c591d523301b3e72850baf Mon Sep 17 00:00:00 2001 From: ivan Date: Mon, 15 Jul 2024 16:02:06 +0300 Subject: [PATCH 3/8] Exclude tests from coverage --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 75d8f281..f1bbc3c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,11 @@ asyncio_mode = "auto" [tool.coverage.report] exclude_also = ["if typing.TYPE_CHECKING:"] +[tool.coverage.run] +omit = [ + "*/tests/*" +] + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" From e7b30a4c2920b263b4279ab4b773c7d311a15299 Mon Sep 17 00:00:00 2001 From: ivan Date: Mon, 15 Jul 2024 16:02:11 +0300 Subject: [PATCH 4/8] Tests --- tests/providers/test_providers_overriding.py | 64 ++++++++++++++------ 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/tests/providers/test_providers_overriding.py b/tests/providers/test_providers_overriding.py index ccefbcfa..2deb8b98 100644 --- a/tests/providers/test_providers_overriding.py +++ b/tests/providers/test_providers_overriding.py @@ -1,5 +1,7 @@ import datetime +import pytest + from tests import container @@ -9,16 +11,20 @@ async def test_providers_overriding() -> None: async_factory_mock = datetime.datetime.fromisoformat("2025-01-01") simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) singleton_mock = container.SingletonFactory(dep1=False) - container.DIContainer.async_resource.override(async_resource_mock) - container.DIContainer.sync_resource.override(sync_resource_mock) - container.DIContainer.simple_factory.override(simple_factory_mock) - container.DIContainer.singleton.override(singleton_mock) - container.DIContainer.async_factory.override(async_factory_mock) - await container.DIContainer.simple_factory() - dependent_factory = await container.DIContainer.dependent_factory() - singleton = await container.DIContainer.singleton() - async_factory = await container.DIContainer.async_factory() + providers_for_overriding = { + "async_resource": async_resource_mock, + "sync_resource": sync_resource_mock, + "simple_factory": simple_factory_mock, + "singleton": singleton_mock, + "async_factory": async_factory_mock, + } + + with container.DIContainer.override_providers(providers_for_overriding): + await container.DIContainer.simple_factory() + dependent_factory = await container.DIContainer.dependent_factory() + singleton = await container.DIContainer.singleton() + async_factory = await container.DIContainer.async_factory() assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 @@ -27,7 +33,6 @@ async def test_providers_overriding() -> None: assert singleton is singleton_mock assert async_factory is async_factory_mock - container.DIContainer.reset_override() assert (await container.DIContainer.async_resource()) != async_resource_mock @@ -36,15 +41,19 @@ async def test_providers_overriding_sync_resolve() -> None: sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) singleton_mock = container.SingletonFactory(dep1=False) - container.DIContainer.async_resource.override(async_resource_mock) - container.DIContainer.sync_resource.override(sync_resource_mock) - container.DIContainer.simple_factory.override(simple_factory_mock) - container.DIContainer.singleton.override(singleton_mock) - container.DIContainer.simple_factory.sync_resolve() - await container.DIContainer.async_resource.async_resolve() - dependent_factory = container.DIContainer.dependent_factory.sync_resolve() - singleton = container.DIContainer.singleton.sync_resolve() + providers_for_overriding = { + "async_resource": async_resource_mock, + "sync_resource": sync_resource_mock, + "simple_factory": simple_factory_mock, + "singleton": singleton_mock, + } + + with container.DIContainer.override_providers(providers_for_overriding): + container.DIContainer.simple_factory.sync_resolve() + await container.DIContainer.async_resource.async_resolve() + dependent_factory = container.DIContainer.dependent_factory.sync_resolve() + singleton = container.DIContainer.singleton.sync_resolve() assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 @@ -52,5 +61,22 @@ async def test_providers_overriding_sync_resolve() -> None: assert dependent_factory.async_resource == async_resource_mock assert singleton is singleton_mock - container.DIContainer.reset_override() assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock + + +def test_providers_overriding_with_context_manager() -> None: + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) + + with container.DIContainer.simple_factory.override_context(simple_factory_mock): + assert container.DIContainer.simple_factory.sync_resolve() is simple_factory_mock + + assert container.DIContainer.simple_factory.sync_resolve() is not simple_factory_mock + + +def test_providers_overriding_fail_with_unknown_provider() -> None: + unknown_provider_name = "unknown_provider_name" + match = f"Provider with name {unknown_provider_name!r} not found" + providers_for_overriding = {unknown_provider_name: None} + + with pytest.raises(RuntimeError, match=match), container.DIContainer.override_providers(providers_for_overriding): + ... From 94cb01d392bf115f77c244330c90fd99966d0bf1 Mon Sep 17 00:00:00 2001 From: ivan Date: Mon, 15 Jul 2024 20:19:18 +0300 Subject: [PATCH 5/8] Separate tests with batch and single provider overriding --- tests/providers/test_providers_overriding.py | 57 +++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/tests/providers/test_providers_overriding.py b/tests/providers/test_providers_overriding.py index 2deb8b98..916d6cc3 100644 --- a/tests/providers/test_providers_overriding.py +++ b/tests/providers/test_providers_overriding.py @@ -5,7 +5,7 @@ from tests import container -async def test_providers_overriding() -> None: +async def test_batch_providers_overriding() -> None: async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") async_factory_mock = datetime.datetime.fromisoformat("2025-01-01") @@ -36,7 +36,7 @@ async def test_providers_overriding() -> None: assert (await container.DIContainer.async_resource()) != async_resource_mock -async def test_providers_overriding_sync_resolve() -> None: +async def test_batch_providers_overriding_sync_resolve() -> None: async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) @@ -80,3 +80,56 @@ def test_providers_overriding_fail_with_unknown_provider() -> None: with pytest.raises(RuntimeError, match=match), container.DIContainer.override_providers(providers_for_overriding): ... + + +async def test_providers_overriding() -> None: + async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") + sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") + async_factory_mock = datetime.datetime.fromisoformat("2025-01-01") + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) + singleton_mock = container.SingletonFactory(dep1=False) + container.DIContainer.async_resource.override(async_resource_mock) + container.DIContainer.sync_resource.override(sync_resource_mock) + container.DIContainer.simple_factory.override(simple_factory_mock) + container.DIContainer.singleton.override(singleton_mock) + container.DIContainer.async_factory.override(async_factory_mock) + + await container.DIContainer.simple_factory() + dependent_factory = await container.DIContainer.dependent_factory() + singleton = await container.DIContainer.singleton() + async_factory = await container.DIContainer.async_factory() + + assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 + assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 + assert dependent_factory.sync_resource == sync_resource_mock + assert dependent_factory.async_resource == async_resource_mock + assert singleton is singleton_mock + assert async_factory is async_factory_mock + + container.DIContainer.reset_override() + assert (await container.DIContainer.async_resource()) != async_resource_mock + + +async def test_providers_overriding_sync_resolve() -> None: + async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") + sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) + singleton_mock = container.SingletonFactory(dep1=False) + container.DIContainer.async_resource.override(async_resource_mock) + container.DIContainer.sync_resource.override(sync_resource_mock) + container.DIContainer.simple_factory.override(simple_factory_mock) + container.DIContainer.singleton.override(singleton_mock) + + container.DIContainer.simple_factory.sync_resolve() + await container.DIContainer.async_resource.async_resolve() + dependent_factory = container.DIContainer.dependent_factory.sync_resolve() + singleton = container.DIContainer.singleton.sync_resolve() + + assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 + assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 + assert dependent_factory.sync_resource == sync_resource_mock + assert dependent_factory.async_resource == async_resource_mock + assert singleton is singleton_mock + + container.DIContainer.reset_override() + assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock From 60173781eeec1849a48ba46f468f7a271e3312f9 Mon Sep 17 00:00:00 2001 From: ivan Date: Mon, 15 Jul 2024 20:19:48 +0300 Subject: [PATCH 6/8] Remove omit tests from coverage --- pyproject.toml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f1bbc3c6..75d8f281 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,11 +62,6 @@ asyncio_mode = "auto" [tool.coverage.report] exclude_also = ["if typing.TYPE_CHECKING:"] -[tool.coverage.run] -omit = [ - "*/tests/*" -] - [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" From 0b39b9a818092b46f0946fb95fe2a1f94a3785c5 Mon Sep 17 00:00:00 2001 From: ivan Date: Mon, 15 Jul 2024 20:34:51 +0300 Subject: [PATCH 7/8] Wrap yield with try-finally --- that_depends/container.py | 11 ++++++----- that_depends/providers/base.py | 6 ++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/that_depends/container.py b/that_depends/container.py index 35a9977f..392efa47 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -110,8 +110,9 @@ def override_providers(cls, providers_for_overriding: dict[str, typing.Any]) -> provider = current_providers[provider_name] provider.override(mock) - yield - - for provider_name in providers_for_overriding: - provider = current_providers[provider_name] - provider.reset_override() + try: + yield + finally: + for provider_name in providers_for_overriding: + provider = current_providers[provider_name] + provider.reset_override() diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index c331910b..353d1a7b 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -28,8 +28,10 @@ def override(self, mock: object) -> None: @contextmanager def override_context(self, mock: object) -> typing.Iterator[None]: self.override(mock) - yield - self.reset_override() + try: + yield + finally: + self.reset_override() def reset_override(self) -> None: self._override = None From 792bda620d7f86739ce4365e6a3f7a481d306070 Mon Sep 17 00:00:00 2001 From: ivan Date: Mon, 15 Jul 2024 20:40:39 +0300 Subject: [PATCH 8/8] Add no cover for line with ellipsis --- tests/providers/test_providers_overriding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/test_providers_overriding.py b/tests/providers/test_providers_overriding.py index 916d6cc3..a6dd0b5e 100644 --- a/tests/providers/test_providers_overriding.py +++ b/tests/providers/test_providers_overriding.py @@ -79,7 +79,7 @@ def test_providers_overriding_fail_with_unknown_provider() -> None: providers_for_overriding = {unknown_provider_name: None} with pytest.raises(RuntimeError, match=match), container.DIContainer.override_providers(providers_for_overriding): - ... + ... # pragma: no cover async def test_providers_overriding() -> None: