Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions tests/providers/test_providers_overriding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,87 @@
import datetime

import pytest

from tests import container


async def test_batch_providers_overriding() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
async_factory_mock = datetime.datetime.fromisoformat("2025-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)

providers_for_overriding = {
"async_resource": async_resource_mock,
"sync_resource": sync_resource_mock,
"simple_factory": simple_factory_mock,
"singleton": singleton_mock,
"async_factory": async_factory_mock,
}

with container.DIContainer.override_providers(providers_for_overriding):
await container.DIContainer.simple_factory()
dependent_factory = await container.DIContainer.dependent_factory()
singleton = await container.DIContainer.singleton()
async_factory = await container.DIContainer.async_factory()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock
assert async_factory is async_factory_mock

assert (await container.DIContainer.async_resource()) != async_resource_mock


async def test_batch_providers_overriding_sync_resolve() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)

providers_for_overriding = {
"async_resource": async_resource_mock,
"sync_resource": sync_resource_mock,
"simple_factory": simple_factory_mock,
"singleton": singleton_mock,
}

with container.DIContainer.override_providers(providers_for_overriding):
container.DIContainer.simple_factory.sync_resolve()
await container.DIContainer.async_resource.async_resolve()
dependent_factory = container.DIContainer.dependent_factory.sync_resolve()
singleton = container.DIContainer.singleton.sync_resolve()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock

assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock


def test_providers_overriding_with_context_manager() -> None:
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)

with container.DIContainer.simple_factory.override_context(simple_factory_mock):
assert container.DIContainer.simple_factory.sync_resolve() is simple_factory_mock

assert container.DIContainer.simple_factory.sync_resolve() is not simple_factory_mock


def test_providers_overriding_fail_with_unknown_provider() -> None:
unknown_provider_name = "unknown_provider_name"
match = f"Provider with name {unknown_provider_name!r} not found"
providers_for_overriding = {unknown_provider_name: None}

with pytest.raises(RuntimeError, match=match), container.DIContainer.override_providers(providers_for_overriding):
... # pragma: no cover


async def test_providers_overriding() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
Expand Down
24 changes: 24 additions & 0 deletions that_depends/container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import typing
from contextlib import contextmanager

from that_depends.providers import AbstractProvider, AbstractResource, Singleton

Expand Down Expand Up @@ -92,3 +93,26 @@ async def resolve(cls, object_to_resolve: type[T] | typing.Callable[..., T]) ->
kwargs[field_name] = await providers[field_name].async_resolve()

return object_to_resolve(**kwargs)

@classmethod
@contextmanager
def override_providers(cls, providers_for_overriding: dict[str, typing.Any]) -> typing.Iterator[None]:
current_providers = cls.get_providers()
current_provider_names = set(current_providers.keys())
given_provider_names = set(providers_for_overriding.keys())

for given_name in given_provider_names:
if given_name not in current_provider_names:
msg = f"Provider with name {given_name!r} not found"
raise RuntimeError(msg)

for provider_name, mock in providers_for_overriding.items():
provider = current_providers[provider_name]
provider.override(mock)

try:
yield
finally:
for provider_name in providers_for_overriding:
provider = current_providers[provider_name]
provider.reset_override()
9 changes: 9 additions & 0 deletions that_depends/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import typing
from contextlib import contextmanager


T = typing.TypeVar("T")
Expand All @@ -24,6 +25,14 @@ async def __call__(self) -> T_co:
def override(self, mock: object) -> None:
self._override = mock

@contextmanager
def override_context(self, mock: object) -> typing.Iterator[None]:
self.override(mock)
try:
yield
finally:
self.reset_override()

def reset_override(self) -> None:
self._override = None

Expand Down