diff --git a/src/attr/_compat.py b/src/attr/_compat.py index 35a85a3fa..cc6327ec8 100644 --- a/src/attr/_compat.py +++ b/src/attr/_compat.py @@ -9,6 +9,7 @@ import warnings from collections.abc import Mapping, Sequence # noqa +from typing import _GenericAlias PYPY = platform.python_implementation() == "PyPy" @@ -174,3 +175,10 @@ def func(): # don't have a direct reference to the thread-local in their globals dict. # If they have such a reference, it breaks cloudpickle. repr_context = threading.local() + + +def get_generic_base(cl): + """If this is a generic class (A[str]), return the generic base for it.""" + if cl.__class__ is _GenericAlias: + return cl.__origin__ + return None diff --git a/src/attr/_funcs.py b/src/attr/_funcs.py index 518be16eb..6fa2456dc 100644 --- a/src/attr/_funcs.py +++ b/src/attr/_funcs.py @@ -3,6 +3,7 @@ import copy +from ._compat import get_generic_base from ._make import NOTHING, _obj_setattr, fields from .exceptions import AttrsAttributeNotFoundError @@ -296,7 +297,19 @@ def has(cls): :rtype: bool """ - return getattr(cls, "__attrs_attrs__", None) is not None + attrs = getattr(cls, "__attrs_attrs__", None) + if attrs is not None: + return True + + # No attrs, maybe it's a specialized generic (A[str])? + generic_base = get_generic_base(cls) + if generic_base is not None: + generic_attrs = getattr(generic_base, "__attrs_attrs__", None) + if generic_attrs is not None: + # Stick it on here for speed next time. + cls.__attrs_attrs__ = generic_attrs + return generic_attrs is not None + return False def assoc(inst, **changes): diff --git a/src/attr/_make.py b/src/attr/_make.py index 014b7bc23..d72f738ee 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -12,7 +12,12 @@ # We need to import _compat itself in addition to the _compat members to avoid # having the thread-local in the globals here. from . import _compat, _config, setters -from ._compat import PY310, _AnnotationExtractor, set_closure_cell +from ._compat import ( + PY310, + _AnnotationExtractor, + get_generic_base, + set_closure_cell, +) from .exceptions import ( DefaultAlreadySetError, FrozenInstanceError, @@ -1918,12 +1923,26 @@ def fields(cls): .. versionchanged:: 16.2.0 Returned tuple allows accessing the fields by name. + .. versionchanged:: 23.1.0 Add support for generic classes. """ - if not isinstance(cls, type): + generic_base = get_generic_base(cls) + + if generic_base is None and not isinstance(cls, type): raise TypeError("Passed object must be a class.") + attrs = getattr(cls, "__attrs_attrs__", None) + if attrs is None: + if generic_base is not None: + attrs = getattr(generic_base, "__attrs_attrs__", None) + if attrs is not None: + # Even though this is global state, stick it on here to speed + # it up. We rely on `cls` being cached for this to be + # efficient. + cls.__attrs_attrs__ = attrs + return attrs raise NotAnAttrsClassError(f"{cls!r} is not an attrs-decorated class.") + return attrs diff --git a/tests/test_funcs.py b/tests/test_funcs.py index f77bfd4ab..9f6845647 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -6,6 +6,7 @@ from collections import OrderedDict +from typing import Generic, TypeVar import pytest @@ -418,6 +419,37 @@ def test_negative(self): """ assert not has(object) + def test_generics(self): + """ + Works with generic classes. + """ + T = TypeVar("T") + + @attr.define + class A(Generic[T]): + a: T + + assert has(A) + + assert has(A[str]) + # Verify twice, since there's caching going on. + assert has(A[str]) + + def test_generics_negative(self): + """ + Returns `False` on non-decorated generic classes. + """ + T = TypeVar("T") + + class A(Generic[T]): + a: T + + assert not has(A) + + assert not has(A[str]) + # Verify twice, since there's caching going on. + assert not has(A[str]) + class TestAssoc: """ diff --git a/tests/test_make.py b/tests/test_make.py index 79373d3ef..127de5d97 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -13,6 +13,7 @@ import sys from operator import attrgetter +from typing import Generic, TypeVar import pytest @@ -1114,6 +1115,22 @@ def test_handler_non_attrs_class(self): f"{object!r} is not an attrs-decorated class." ) == e.value.args[0] + def test_handler_non_attrs_generic_class(self): + """ + Raises `ValueError` if passed a non-*attrs* generic class. + """ + T = TypeVar("T") + + class B(Generic[T]): + pass + + with pytest.raises(NotAnAttrsClassError) as e: + fields(B[str]) + + assert ( + f"{B[str]!r} is not an attrs-decorated class." + ) == e.value.args[0] + @given(simple_classes()) def test_fields(self, C): """ @@ -1129,6 +1146,24 @@ def test_fields_properties(self, C): for attribute in fields(C): assert getattr(fields(C), attribute.name) is attribute + def test_generics(self): + """ + Fields work with generic classes. + """ + T = TypeVar("T") + + @attr.define + class A(Generic[T]): + a: T + + assert len(fields(A)) == 1 + assert fields(A).a.name == "a" + assert fields(A).a.default is attr.NOTHING + + assert len(fields(A[str])) == 1 + assert fields(A[str]).a.name == "a" + assert fields(A[str]).a.default is attr.NOTHING + class TestFieldsDict: """