diff --git a/mypy/checker.py b/mypy/checker.py index 17e894b9bc33..5a69f502540a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -35,7 +35,7 @@ UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, is_named_instance, union_items, TypeQuery, LiteralType, is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, - get_proper_types, is_literal_type, TypeAliasType) + get_proper_types, is_literal_type, TypeAliasType, TypeGuardType) from mypy.sametypes import is_same_type from mypy.messages import ( MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq, @@ -3957,6 +3957,7 @@ def find_isinstance_check(self, node: Expression ) -> Tuple[TypeMap, TypeMap]: """Find any isinstance checks (within a chain of ands). Includes implicit and explicit checks for None and calls to callable. + Also includes TypeGuard functions. Return value is a map of variables to their types if the condition is true and a map of variables to their types if the condition is false. @@ -4001,6 +4002,14 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if literal(expr) == LITERAL_TYPE: vartype = type_map[expr] return self.conditional_callable_type_map(expr, vartype) + elif isinstance(node.callee, RefExpr): + if node.callee.type_guard is not None: + # TODO: Follow keyword args or *args, **kwargs + if node.arg_kinds[0] != nodes.ARG_POS: + self.fail("Type guard requires positional argument", node) + return {}, {} + if literal(expr) == LITERAL_TYPE: + return {expr: TypeGuardType(node.callee.type_guard)}, {} elif isinstance(node, ComparisonExpr): # Step 1: Obtain the types of each operand and whether or not we can # narrow their types. (For example, we shouldn't try narrowing the diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 40204e7c9ccf..4a924d643676 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -14,7 +14,7 @@ make_optional_type, ) from mypy.types import ( - Type, AnyType, CallableType, Overloaded, NoneType, TypeVarDef, + Type, AnyType, CallableType, Overloaded, NoneType, TypeGuardType, TypeVarDef, TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue, is_named_instance, FunctionLike, @@ -317,6 +317,11 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> ret_type=self.object_type(), fallback=self.named_type('builtins.function')) callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True)) + if (isinstance(e.callee, RefExpr) + and isinstance(callee_type, CallableType) + and callee_type.type_guard is not None): + # Cache it for find_isinstance_check() + e.callee.type_guard = callee_type.type_guard if (self.chk.options.disallow_untyped_calls and self.chk.in_checked_function() and isinstance(callee_type, CallableType) @@ -4163,6 +4168,10 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type, """ if literal(expr) >= LITERAL_TYPE: restriction = self.chk.binder.get(expr) + # Ignore the error about using get_proper_type(). + if isinstance(restriction, TypeGuardType): # type: ignore[misc] + # A type guard forces the new type even if it doesn't overlap the old. + return restriction.type_guard # If the current node is deferred, some variables may get Any types that they # otherwise wouldn't have. We don't want to narrow down these since it may # produce invalid inferred Optional[Any] types, at least. diff --git a/mypy/constraints.py b/mypy/constraints.py index 89b8e4527e24..70265285dadc 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -457,7 +457,12 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: for t, a in zip(template.arg_types, cactual.arg_types): # Negate direction due to function argument type contravariance. res.extend(infer_constraints(t, a, neg_op(self.direction))) - res.extend(infer_constraints(template.ret_type, cactual.ret_type, + template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type + if template.type_guard is not None: + template_ret_type = template.type_guard + if cactual.type_guard is not None: + cactual_ret_type = cactual.type_guard + res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) return res elif isinstance(self.actual, AnyType): diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 2e3db6b109a4..f98e0750743b 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -97,7 +97,9 @@ def visit_type_var(self, t: TypeVarType) -> Type: def visit_callable_type(self, t: CallableType) -> Type: return t.copy_modified(arg_types=self.expand_types(t.arg_types), - ret_type=t.ret_type.accept(self)) + ret_type=t.ret_type.accept(self), + type_guard=(t.type_guard.accept(self) + if t.type_guard is not None else None)) def visit_overloaded(self, t: Overloaded) -> Type: items = [] # type: List[CallableType] diff --git a/mypy/fixup.py b/mypy/fixup.py index 30e1a0dae2b9..b90dba971e4f 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -192,6 +192,8 @@ def visit_callable_type(self, ct: CallableType) -> None: for arg in ct.bound_args: if arg: arg.accept(self) + if ct.type_guard is not None: + ct.type_guard.accept(self) def visit_overloaded(self, t: Overloaded) -> None: for ct in t.items(): diff --git a/mypy/nodes.py b/mypy/nodes.py index 0571788bf002..76521e8c2b38 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1448,7 +1448,8 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class RefExpr(Expression): """Abstract base class for name-like constructs""" - __slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue') + __slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue', + 'type_guard') def __init__(self) -> None: super().__init__() @@ -1467,6 +1468,8 @@ def __init__(self) -> None: self.is_inferred_def = False # Is this expression appears as an rvalue of a valid type alias definition? self.is_alias_rvalue = False + # Cache type guard from callable_type.type_guard + self.type_guard = None # type: Optional[mypy.types.Type] class NameExpr(RefExpr): diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index eb1dbd9dcc30..eb61e66ddcf6 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -92,6 +92,7 @@ 'check-annotated.test', 'check-parameter-specification.test', 'check-generic-alias.test', + 'check-typeguard.test', ] # Tests that use Python 3.8-only AST features (like expression-scoped ignores): diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 219be131c4af..3554f638d27c 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -345,6 +345,9 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt " and at least one annotation", t) return AnyType(TypeOfAny.from_error) return self.anal_type(t.args[0]) + elif self.anal_type_guard_arg(t, fullname) is not None: + # In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args) + return self.named_type('builtins.bool') return None def get_omitted_any(self, typ: Type, fullname: Optional[str] = None) -> AnyType: @@ -524,15 +527,34 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: variables = t.variables else: variables = self.bind_function_type_variables(t, t) + special = self.anal_type_guard(t.ret_type) ret = t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=nested), ret_type=self.anal_type(t.ret_type, nested=nested), # If the fallback isn't filled in yet, # its type will be the falsey FakeInfo fallback=(t.fallback if t.fallback.type else self.named_type('builtins.function')), - variables=self.anal_var_defs(variables)) + variables=self.anal_var_defs(variables), + type_guard=special, + ) return ret + def anal_type_guard(self, t: Type) -> Optional[Type]: + if isinstance(t, UnboundType): + sym = self.lookup_qualified(t.name, t) + if sym is not None and sym.node is not None: + return self.anal_type_guard_arg(t, sym.node.fullname) + # TODO: What if it's an Instance? Then use t.type.fullname? + return None + + def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]: + if fullname in ('typing_extensions.TypeGuard', 'typing.TypeGuard'): + if len(t.args) != 1: + self.fail("TypeGuard must have exactly one type argument", t) + return AnyType(TypeOfAny.from_error) + return self.anal_type(t.args[0]) + return None + def visit_overloaded(self, t: Overloaded) -> Type: # Overloaded types are manually constructed in semanal.py by analyzing the # AST and combining together the Callable types this visitor converts. diff --git a/mypy/types.py b/mypy/types.py index 10def3826120..bf138f343b5a 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -270,6 +270,16 @@ def copy_modified(self, *, self.line, self.column) +class TypeGuardType(Type): + """Only used by find_instance_check() etc.""" + def __init__(self, type_guard: Type): + super().__init__(line=type_guard.line, column=type_guard.column) + self.type_guard = type_guard + + def __repr__(self) -> str: + return "TypeGuard({})".format(self.type_guard) + + class ProperType(Type): """Not a type alias. @@ -1005,6 +1015,7 @@ class CallableType(FunctionLike): # tools that consume mypy ASTs 'def_extras', # Information about original definition we want to serialize. # This is used for more detailed error messages. + 'type_guard', # T, if -> TypeGuard[T] (ret_type is bool in this case). ) def __init__(self, @@ -1024,6 +1035,7 @@ def __init__(self, from_type_type: bool = False, bound_args: Sequence[Optional[Type]] = (), def_extras: Optional[Dict[str, Any]] = None, + type_guard: Optional[Type] = None, ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -1058,6 +1070,7 @@ def __init__(self, not definition.is_static else None} else: self.def_extras = {} + self.type_guard = type_guard def copy_modified(self, arg_types: Bogus[Sequence[Type]] = _dummy, @@ -1075,7 +1088,9 @@ def copy_modified(self, special_sig: Bogus[Optional[str]] = _dummy, from_type_type: Bogus[bool] = _dummy, bound_args: Bogus[List[Optional[Type]]] = _dummy, - def_extras: Bogus[Dict[str, Any]] = _dummy) -> 'CallableType': + def_extras: Bogus[Dict[str, Any]] = _dummy, + type_guard: Bogus[Optional[Type]] = _dummy, + ) -> 'CallableType': return CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds, @@ -1094,6 +1109,7 @@ def copy_modified(self, from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type, bound_args=bound_args if bound_args is not _dummy else self.bound_args, def_extras=def_extras if def_extras is not _dummy else dict(self.def_extras), + type_guard=type_guard if type_guard is not _dummy else self.type_guard, ) def var_arg(self) -> Optional[FormalArgument]: @@ -1255,6 +1271,8 @@ def __eq__(self, other: object) -> bool: def serialize(self) -> JsonDict: # TODO: As an optimization, leave out everything related to # generic functions for non-generic functions. + assert (self.type_guard is None + or isinstance(get_proper_type(self.type_guard), Instance)), str(self.type_guard) return {'.class': 'CallableType', 'arg_types': [t.serialize() for t in self.arg_types], 'arg_kinds': self.arg_kinds, @@ -1269,6 +1287,7 @@ def serialize(self) -> JsonDict: 'bound_args': [(None if t is None else t.serialize()) for t in self.bound_args], 'def_extras': dict(self.def_extras), + 'type_guard': self.type_guard.serialize() if self.type_guard is not None else None, } @classmethod @@ -1286,7 +1305,9 @@ def deserialize(cls, data: JsonDict) -> 'CallableType': implicit=data['implicit'], bound_args=[(None if t is None else deserialize_type(t)) for t in data['bound_args']], - def_extras=data['def_extras'] + def_extras=data['def_extras'], + type_guard=(deserialize_type(data['type_guard']) + if data['type_guard'] is not None else None), ) @@ -2097,7 +2118,10 @@ def visit_callable_type(self, t: CallableType) -> str: s = '({})'.format(s) if not isinstance(get_proper_type(t.ret_type), NoneType): - s += ' -> {}'.format(t.ret_type.accept(self)) + if t.type_guard is not None: + s += ' -> TypeGuard[{}]'.format(t.type_guard.accept(self)) + else: + s += ' -> {}'.format(t.ret_type.accept(self)) if t.variables: vs = [] diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index dcbf96ac850f..7cb571cedc8d 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -392,3 +392,12 @@ def func() -> None: class Foo: def __init__(self) -> None: self.x = 123 + +[case testWalrusTypeGuard] +from typing_extensions import TypeGuard +def is_float(a: object) -> TypeGuard[float]: pass +def main(a: object) -> None: + if is_float(x := a): + reveal_type(x) # N: Revealed type is 'builtins.float' + reveal_type(a) # N: Revealed type is 'builtins.object' +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-serialize.test b/test-data/unit/check-serialize.test index 1aa9ac0662a2..b4982cc6f70a 100644 --- a/test-data/unit/check-serialize.test +++ b/test-data/unit/check-serialize.test @@ -224,6 +224,21 @@ def f(x: int) -> int: pass tmp/a.py:2: note: Revealed type is 'builtins.str' tmp/a.py:3: error: Unexpected keyword argument "x" for "f" +[case testSerializeTypeGuardFunction] +import a +[file a.py] +import b +[file a.py.2] +import b +reveal_type(b.guard('')) +reveal_type(b.guard) +[file b.py] +from typing_extensions import TypeGuard +def guard(a: object) -> TypeGuard[str]: pass +[builtins fixtures/tuple.pyi] +[out2] +tmp/a.py:2: note: Revealed type is 'builtins.bool' +tmp/a.py:3: note: Revealed type is 'def (a: builtins.object) -> TypeGuard[builtins.str]' -- -- Classes -- diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test new file mode 100644 index 000000000000..e4bf3dd5c931 --- /dev/null +++ b/test-data/unit/check-typeguard.test @@ -0,0 +1,296 @@ +[case testTypeGuardBasic] +from typing_extensions import TypeGuard +class Point: pass +def is_point(a: object) -> TypeGuard[Point]: pass +def main(a: object) -> None: + if is_point(a): + reveal_type(a) # N: Revealed type is '__main__.Point' + else: + reveal_type(a) # N: Revealed type is 'builtins.object' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardTypeArgsNone] +from typing_extensions import TypeGuard +def foo(a: object) -> TypeGuard: # E: TypeGuard must have exactly one type argument + pass +[builtins fixtures/tuple.pyi] + +[case testTypeGuardTypeArgsTooMany] +from typing_extensions import TypeGuard +def foo(a: object) -> TypeGuard[int, int]: # E: TypeGuard must have exactly one type argument + pass +[builtins fixtures/tuple.pyi] + +[case testTypeGuardTypeArgType] +from typing_extensions import TypeGuard +def foo(a: object) -> TypeGuard[42]: # E: Invalid type: try using Literal[42] instead? + pass +[builtins fixtures/tuple.pyi] + +[case testTypeGuardRepr] +from typing_extensions import TypeGuard +def foo(a: object) -> TypeGuard[int]: + pass +reveal_type(foo) # N: Revealed type is 'def (a: builtins.object) -> TypeGuard[builtins.int]' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardCallArgsNone] +from typing_extensions import TypeGuard +class Point: pass +# TODO: error on the 'def' line (insufficient args for type guard) +def is_point() -> TypeGuard[Point]: pass +def main(a: object) -> None: + if is_point(): + reveal_type(a) # N: Revealed type is 'builtins.object' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardCallArgsMultiple] +from typing_extensions import TypeGuard +class Point: pass +def is_point(a: object, b: object) -> TypeGuard[Point]: pass +def main(a: object, b: object) -> None: + if is_point(a, b): + reveal_type(a) # N: Revealed type is '__main__.Point' + reveal_type(b) # N: Revealed type is 'builtins.object' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardIsBool] +from typing_extensions import TypeGuard +def f(a: TypeGuard[int]) -> None: pass +reveal_type(f) # N: Revealed type is 'def (a: builtins.bool)' +a: TypeGuard[int] +reveal_type(a) # N: Revealed type is 'builtins.bool' +class C: + a: TypeGuard[int] +reveal_type(C().a) # N: Revealed type is 'builtins.bool' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardWithTypeVar] +from typing import TypeVar, Tuple +from typing_extensions import TypeGuard +T = TypeVar('T') +def is_two_element_tuple(a: Tuple[T, ...]) -> TypeGuard[Tuple[T, T]]: pass +def main(a: Tuple[T, ...]): + if is_two_element_tuple(a): + reveal_type(a) # N: Revealed type is 'Tuple[T`-1, T`-1]' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardNonOverlapping] +from typing import List +from typing_extensions import TypeGuard +def is_str_list(a: List[object]) -> TypeGuard[List[str]]: pass +def main(a: List[object]): + if is_str_list(a): + reveal_type(a) # N: Revealed type is 'builtins.list[builtins.str]' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardUnionIn] +from typing import Union +from typing_extensions import TypeGuard +def is_foo(a: Union[int, str]) -> TypeGuard[str]: pass +def main(a: Union[str, int]) -> None: + if is_foo(a): + reveal_type(a) # N: Revealed type is 'builtins.str' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardUnionOut] +from typing import Union +from typing_extensions import TypeGuard +def is_foo(a: object) -> TypeGuard[Union[int, str]]: pass +def main(a: object) -> None: + if is_foo(a): + reveal_type(a) # N: Revealed type is 'Union[builtins.int, builtins.str]' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardNonzeroFloat] +from typing_extensions import TypeGuard +def is_nonzero(a: object) -> TypeGuard[float]: pass +def main(a: int): + if is_nonzero(a): + reveal_type(a) # N: Revealed type is 'builtins.float' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardHigherOrder] +from typing import Callable, TypeVar, Iterable, List +from typing_extensions import TypeGuard +T = TypeVar('T') +R = TypeVar('R') +def filter(f: Callable[[T], TypeGuard[R]], it: Iterable[T]) -> Iterable[R]: pass +def is_float(a: object) -> TypeGuard[float]: pass +a: List[object] = ["a", 0, 0.0] +b = filter(is_float, a) +reveal_type(b) # N: Revealed type is 'typing.Iterable[builtins.float*]' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardMethod] +from typing_extensions import TypeGuard +class C: + def main(self, a: object) -> None: + if self.is_float(a): + reveal_type(self) # N: Revealed type is '__main__.C' + reveal_type(a) # N: Revealed type is 'builtins.float' + def is_float(self, a: object) -> TypeGuard[float]: pass +[builtins fixtures/tuple.pyi] + +[case testTypeGuardCrossModule] +import guard +from points import Point +def main(a: object) -> None: + if guard.is_point(a): + reveal_type(a) # N: Revealed type is 'points.Point' +[file guard.py] +from typing_extensions import TypeGuard +import points +def is_point(a: object) -> TypeGuard[points.Point]: pass +[file points.py] +class Point: pass +[builtins fixtures/tuple.pyi] + +[case testTypeGuardBodyRequiresBool] +from typing_extensions import TypeGuard +def is_float(a: object) -> TypeGuard[float]: + return "not a bool" # E: Incompatible return value type (got "str", expected "bool") +[builtins fixtures/tuple.pyi] + +[case testTypeGuardNarrowToTypedDict] +from typing import Dict, TypedDict +from typing_extensions import TypeGuard +class User(TypedDict): + name: str + id: int +def is_user(a: Dict[str, object]) -> TypeGuard[User]: + return isinstance(a.get("name"), str) and isinstance(a.get("id"), int) +def main(a: Dict[str, object]) -> None: + if is_user(a): + reveal_type(a) # N: Revealed type is 'TypedDict('__main__.User', {'name': builtins.str, 'id': builtins.int})' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypeGuardInAssert] +from typing_extensions import TypeGuard +def is_float(a: object) -> TypeGuard[float]: pass +def main(a: object) -> None: + assert is_float(a) + reveal_type(a) # N: Revealed type is 'builtins.float' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardFromAny] +from typing import Any +from typing_extensions import TypeGuard +def is_objfloat(a: object) -> TypeGuard[float]: pass +def is_anyfloat(a: Any) -> TypeGuard[float]: pass +def objmain(a: object) -> None: + if is_objfloat(a): + reveal_type(a) # N: Revealed type is 'builtins.float' + if is_anyfloat(a): + reveal_type(a) # N: Revealed type is 'builtins.float' +def anymain(a: Any) -> None: + if is_objfloat(a): + reveal_type(a) # N: Revealed type is 'builtins.float' + if is_anyfloat(a): + reveal_type(a) # N: Revealed type is 'builtins.float' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardNegatedAndElse] +from typing import Union +from typing_extensions import TypeGuard +def is_int(a: object) -> TypeGuard[int]: pass +def is_str(a: object) -> TypeGuard[str]: pass +def intmain(a: Union[int, str]) -> None: + if not is_int(a): + reveal_type(a) # N: Revealed type is 'Union[builtins.int, builtins.str]' + else: + reveal_type(a) # N: Revealed type is 'builtins.int' +def strmain(a: Union[int, str]) -> None: + if is_str(a): + reveal_type(a) # N: Revealed type is 'builtins.str' + else: + reveal_type(a) # N: Revealed type is 'Union[builtins.int, builtins.str]' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardClassMethod] +from typing_extensions import TypeGuard +class C: + @classmethod + def is_float(cls, a: object) -> TypeGuard[float]: pass + def method(self, a: object) -> None: + if self.is_float(a): + reveal_type(a) # N: Revealed type is 'builtins.float' +def main(a: object) -> None: + if C.is_float(a): + reveal_type(a) # N: Revealed type is 'builtins.float' +[builtins fixtures/classmethod.pyi] + +[case testTypeGuardRequiresPositionalArgs] +from typing_extensions import TypeGuard +def is_float(a: object, b: object = 0) -> TypeGuard[float]: pass +def main1(a: object) -> None: + # This is debatable -- should we support these cases? + + if is_float(a=a, b=1): # E: Type guard requires positional argument + reveal_type(a) # N: Revealed type is 'builtins.object' + + if is_float(b=1, a=a): # E: Type guard requires positional argument + reveal_type(a) # N: Revealed type is 'builtins.object' + + ta = (a,) + if is_float(*ta): # E: Type guard requires positional argument + reveal_type(ta) # N: Revealed type is 'Tuple[builtins.object]' + reveal_type(a) # N: Revealed type is 'builtins.object' + + la = [a] + if is_float(*la): # E: Type guard requires positional argument + reveal_type(la) # N: Revealed type is 'builtins.list[builtins.object*]' + reveal_type(a) # N: Revealed type is 'builtins.object*' + +[builtins fixtures/tuple.pyi] + +[case testTypeGuardOverload-skip] +# flags: --strict-optional +from typing import overload, Any, Callable, Iterable, Iterator, List, Optional, TypeVar +from typing_extensions import TypeGuard + +T = TypeVar("T") +R = TypeVar("R") + +@overload +def filter(f: Callable[[T], TypeGuard[R]], it: Iterable[T]) -> Iterator[R]: ... +@overload +def filter(f: Callable[[T], bool], it: Iterable[T]) -> Iterator[T]: ... +def filter(*args): pass + +def is_int_typeguard(a: object) -> TypeGuard[int]: pass +def is_int_bool(a: object) -> bool: pass + +def main(a: List[Optional[int]]) -> None: + bb = filter(lambda x: x is not None, a) + reveal_type(bb) # N: Revealed type is 'typing.Iterator[Union[builtins.int, None]]' + # Also, if you replace 'bool' with 'Any' in the second overload, bb is Iterator[Any] + cc = filter(is_int_typeguard, a) + reveal_type(cc) # N: Revealed type is 'typing.Iterator[builtins.int*]' + dd = filter(is_int_bool, a) + reveal_type(dd) # N: Revealed type is 'typing.Iterator[Union[builtins.int, None]]' + +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] + +[case testTypeGuardDecorated] +from typing import TypeVar +from typing_extensions import TypeGuard +T = TypeVar("T") +def decorator(f: T) -> T: pass +@decorator +def is_float(a: object) -> TypeGuard[float]: + pass +def main(a: object) -> None: + if is_float(a): + reveal_type(a) # N: Revealed type is 'builtins.float' +[builtins fixtures/tuple.pyi] + +[case testTypeGuardMethodOverride-skip] +from typing_extensions import TypeGuard +class C: + def is_float(self, a: object) -> TypeGuard[float]: pass +class D(C): + def is_float(self, a: object) -> bool: pass # E: Some error +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi index 946430d106a6..478e5dc1b283 100644 --- a/test-data/unit/lib-stub/typing_extensions.pyi +++ b/test-data/unit/lib-stub/typing_extensions.pyi @@ -24,6 +24,8 @@ Annotated: _SpecialForm = ... ParamSpec: _SpecialForm Concatenate: _SpecialForm +TypeGuard: _SpecialForm + # Fallback type for all typed dicts (does not exist at runtime). class _TypedDict(Mapping[str, object]): # Needed to make this class non-abstract. It is explicitly declared abstract in