Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
CallExpr,
ClassDef,
Context,
Decorator,
Expression,
FuncDef,
IndexExpr,
Expand Down Expand Up @@ -772,13 +773,28 @@ def incompatible_argument(
actual_type_str, expected_type_str
)
else:
if self.prefer_simple_messages():
try:
expected_type = callee.arg_types[m - 1]
except IndexError: # Varargs callees
expected_type = callee.arg_types[-1]

decorator_context = callee_name is None and isinstance(outer_context, Decorator)
simple_message = self.prefer_simple_messages() and not decorator_context

if decorator_context:
decorator = cast(Decorator, outer_context)
arg_type_str, expected_type_str = format_type_distinctly(
arg_type, expected_type, bare=True, options=self.options
)
func_name = decorator.func.name
msg = (
f'Decorated function "{func_name}" has incompatible type '
f"{quote_type_string(arg_type_str)}; expected "
f"{quote_type_string(expected_type_str)}"
)
elif simple_message:
msg = "Argument has incompatible type"
else:
try:
expected_type = callee.arg_types[m - 1]
except IndexError: # Varargs callees
expected_type = callee.arg_types[-1]
arg_type_str, expected_type_str = format_type_distinctly(
arg_type, expected_type, bare=True, options=self.options
)
Expand Down Expand Up @@ -822,6 +838,7 @@ def incompatible_argument(
quote_type_string(arg_type_str),
quote_type_string(expected_type_str),
)
if not simple_message:
expected_type = get_proper_type(expected_type)
if isinstance(expected_type, UnionType):
expected_types = list(expected_type.items)
Expand Down
108 changes: 64 additions & 44 deletions mypy/test/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
import tempfile
from abc import abstractmethod
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from pathlib import Path
from re import Pattern
Expand Down Expand Up @@ -53,6 +53,66 @@ def _file_arg_to_module(filename: str) -> str:
return ".".join(parts)


def _handle_out_section(
item: TestItem,
case: DataDrivenTestCase,
output: list[str],
output2: dict[int, list[str]],
out_section_missing: bool,
item_fail: Callable[[str], NoReturn],
) -> bool:
"""Handle an "out" / "outN" section from a test item.

Mutates `output` (in-place) or `output2` and returns the updated
`out_section_missing` flag.
"""
if item.arg is None:
args = []
else:
args = item.arg.split(",")

version_check = True
for arg in args:
if arg.startswith("version"):
compare_op = arg[7:9]
if compare_op not in {">=", "=="}:
item_fail("Only >= and == version checks are currently supported")
version_str = arg[9:]
try:
version = tuple(int(x) for x in version_str.split("."))
except ValueError:
item_fail(f"{version_str!r} is not a valid python version")
if compare_op == ">=":
if version <= defaults.PYTHON3_VERSION:
item_fail(
f"{arg} always true since minimum runtime version is {defaults.PYTHON3_VERSION}"
)
version_check = sys.version_info >= version
elif compare_op == "==":
if version < defaults.PYTHON3_VERSION:
item_fail(
f"{arg} always false since minimum runtime version is {defaults.PYTHON3_VERSION}"
)
if not 1 < len(version) < 4:
item_fail(
f'Only minor or patch version checks are currently supported with "==": {version_str!r}'
)
version_check = sys.version_info[: len(version)] == version
if version_check:
tmp_output = [expand_variables(line) for line in item.data]
if os.path.sep == "\\" and case.normalize_output:
tmp_output = [fix_win_path(line) for line in tmp_output]
if item.id == "out" or item.id == "out1":
# modify in place so caller's `output` reference is preserved
output[:] = tmp_output
else:
passnum = int(item.id[len("out") :])
assert passnum > 1
output2[passnum] = tmp_output
out_section_missing = False
return out_section_missing


def parse_test_case(case: DataDrivenTestCase) -> None:
"""Parse and prepare a single case from suite with test case descriptions.

Expand Down Expand Up @@ -149,49 +209,9 @@ def _item_fail(msg: str) -> NoReturn:
full = join(base_path, m.group(1))
deleted_paths.setdefault(num, set()).add(full)
elif re.match(r"out[0-9]*$", item.id):
if item.arg is None:
args = []
else:
args = item.arg.split(",")

version_check = True
for arg in args:
if arg.startswith("version"):
compare_op = arg[7:9]
if compare_op not in {">=", "=="}:
_item_fail("Only >= and == version checks are currently supported")
version_str = arg[9:]
try:
version = tuple(int(x) for x in version_str.split("."))
except ValueError:
_item_fail(f"{version_str!r} is not a valid python version")
if compare_op == ">=":
if version <= defaults.PYTHON3_VERSION:
_item_fail(
f"{arg} always true since minimum runtime version is {defaults.PYTHON3_VERSION}"
)
version_check = sys.version_info >= version
elif compare_op == "==":
if version < defaults.PYTHON3_VERSION:
_item_fail(
f"{arg} always false since minimum runtime version is {defaults.PYTHON3_VERSION}"
)
if not 1 < len(version) < 4:
_item_fail(
f'Only minor or patch version checks are currently supported with "==": {version_str!r}'
)
version_check = sys.version_info[: len(version)] == version
if version_check:
tmp_output = [expand_variables(line) for line in item.data]
if os.path.sep == "\\" and case.normalize_output:
tmp_output = [fix_win_path(line) for line in tmp_output]
if item.id == "out" or item.id == "out1":
output = tmp_output
else:
passnum = int(item.id[len("out") :])
assert passnum > 1
output2[passnum] = tmp_output
out_section_missing = False
out_section_missing = _handle_out_section(
item, case, output, output2, out_section_missing, _item_fail
)
elif item.id == "triggered" and item.arg is None:
triggered = item.data
else:
Expand Down
12 changes: 11 additions & 1 deletion test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,16 @@ def dec2(f: Callable[[Any, Any], None]) -> Callable[[Any], None]: pass
@dec2
def f(x, y): pass

[case testDecoratorFactoryApplicationErrorMessage]
from typing import Callable

def decorator(f: object) -> Callable[[Callable[[int], object]], None]: ...
def f(a: int) -> None: ...

@decorator(f) # E: Decorated function "something" has incompatible type "Callable[[], None]"; expected "Callable[[int], object]"
def something() -> None:
pass

[case testNoTypeCheckDecoratorOnMethod1]
from typing import no_type_check

Expand Down Expand Up @@ -3524,7 +3534,7 @@ def decorator2(f: Callable[P, None]) -> Callable[
def key2(x: int) -> None:
...

@decorator2(key2) # E: Argument 1 has incompatible type "def foo2(y: int) -> Coroutine[Any, Any, None]"; expected "def (x: int) -> Awaitable[None]"
@decorator2(key2) # E: Decorated function "foo2" has incompatible type "def foo2(y: int) -> Coroutine[Any, Any, None]"; expected "def (x: int) -> Awaitable[None]"
async def foo2(y: int) -> None:
...

Expand Down