From 3c7e21600874948fb15e6ba2370d3f44a81b9378 Mon Sep 17 00:00:00 2001 From: Richard Si <63936253+ichard26@users.noreply.github.com> Date: Thu, 1 Sep 2022 13:14:38 -0400 Subject: [PATCH] [mypyc] Support __pos__ and __abs__ dunders (#13490) Calls to these dunders on native classes will be specialized to use a direct method call instead of using PyNumber_Absolute. Also calls to abs() on any types have been optimized. They no longer involve a builtins dictionary lookup. It's probably possible to write a C helper function for abs(int) to avoid the C-API entirely for native integers, but I don't feel skilled enough to do that yet. --- mypyc/codegen/emitclass.py | 9 +++++++-- mypyc/doc/native_operations.rst | 1 + mypyc/irbuild/ll_builder.py | 2 ++ mypyc/irbuild/specialize.py | 14 ++++++++++++++ mypyc/primitives/generic_ops.py | 10 ++++++++++ mypyc/test-data/fixtures/ir.py | 12 ++++++++++-- mypyc/test-data/irbuild-any.test | 22 ++++++++++++++++++++++ mypyc/test-data/irbuild-dunders.test | 19 +++++++++++++++++++ mypyc/test-data/run-dunders.test | 11 +++++++++++ 9 files changed, 96 insertions(+), 4 deletions(-) diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index a93ef1b57a1e..99153929231c 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -68,11 +68,15 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: AS_SEQUENCE_SLOT_DEFS: SlotTable = {"__contains__": ("sq_contains", generate_contains_wrapper)} AS_NUMBER_SLOT_DEFS: SlotTable = { + # Unary operations. "__bool__": ("nb_bool", generate_bool_wrapper), - "__neg__": ("nb_negative", generate_dunder_wrapper), - "__invert__": ("nb_invert", generate_dunder_wrapper), "__int__": ("nb_int", generate_dunder_wrapper), "__float__": ("nb_float", generate_dunder_wrapper), + "__neg__": ("nb_negative", generate_dunder_wrapper), + "__pos__": ("nb_positive", generate_dunder_wrapper), + "__abs__": ("nb_absolute", generate_dunder_wrapper), + "__invert__": ("nb_invert", generate_dunder_wrapper), + # Binary operations. "__add__": ("nb_add", generate_bin_op_wrapper), "__radd__": ("nb_add", generate_bin_op_wrapper), "__sub__": ("nb_subtract", generate_bin_op_wrapper), @@ -97,6 +101,7 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: "__rxor__": ("nb_xor", generate_bin_op_wrapper), "__matmul__": ("nb_matrix_multiply", generate_bin_op_wrapper), "__rmatmul__": ("nb_matrix_multiply", generate_bin_op_wrapper), + # In-place binary operations. "__iadd__": ("nb_inplace_add", generate_dunder_wrapper), "__isub__": ("nb_inplace_subtract", generate_dunder_wrapper), "__imul__": ("nb_inplace_multiply", generate_dunder_wrapper), diff --git a/mypyc/doc/native_operations.rst b/mypyc/doc/native_operations.rst index 896217063fee..2587e982feac 100644 --- a/mypyc/doc/native_operations.rst +++ b/mypyc/doc/native_operations.rst @@ -24,6 +24,7 @@ Functions * ``cast(, obj)`` * ``type(obj)`` * ``len(obj)`` +* ``abs(obj)`` * ``id(obj)`` * ``iter(obj)`` * ``next(iter: Iterator)`` diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 14657848e648..c545e86d9561 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -1486,6 +1486,8 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value: if isinstance(typ, RInstance): if expr_op == "-": method = "__neg__" + elif expr_op == "+": + method = "__pos__" elif expr_op == "~": method = "__invert__" else: diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index d09d1bd05687..3e208dccf492 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -34,6 +34,7 @@ from mypy.types import AnyType, TypeOfAny from mypyc.ir.ops import BasicBlock, Integer, RaiseStandardError, Register, Unreachable, Value from mypyc.ir.rtypes import ( + RInstance, RTuple, RType, bool_rprimitive, @@ -138,6 +139,19 @@ def translate_globals(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Va return None +@specialize_function("builtins.abs") +def translate_abs(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: + """Specialize calls on native classes that implement __abs__.""" + if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: + arg = expr.args[0] + arg_typ = builder.node_type(arg) + if isinstance(arg_typ, RInstance) and arg_typ.class_ir.has_method("__abs__"): + obj = builder.accept(arg) + return builder.gen_method_call(obj, "__abs__", [], None, expr.line) + + return None + + @specialize_function("builtins.len") def translate_len(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: diff --git a/mypyc/primitives/generic_ops.py b/mypyc/primitives/generic_ops.py index cdaa94931604..f6817ad024b7 100644 --- a/mypyc/primitives/generic_ops.py +++ b/mypyc/primitives/generic_ops.py @@ -145,6 +145,16 @@ priority=0, ) +# abs(obj) +function_op( + name="builtins.abs", + arg_types=[object_rprimitive], + return_type=object_rprimitive, + c_function_name="PyNumber_Absolute", + error_kind=ERR_MAGIC, + priority=0, +) + # obj1[obj2] method_op( name="__getitem__", diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index e0b706d7ff9d..0e437f4597ea 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -3,7 +3,7 @@ from typing import ( TypeVar, Generic, List, Iterator, Iterable, Dict, Optional, Tuple, Any, Set, - overload, Mapping, Union, Callable, Sequence, FrozenSet + overload, Mapping, Union, Callable, Sequence, FrozenSet, Protocol ) T = TypeVar('T') @@ -12,6 +12,10 @@ K = TypeVar('K') # for keys in mapping V = TypeVar('V') # for values in mapping +class __SupportsAbs(Protocol[T_co]): + def __abs__(self) -> T_co: pass + + class object: def __init__(self) -> None: pass def __eq__(self, x: object) -> bool: pass @@ -40,6 +44,7 @@ def __truediv__(self, x: float) -> float: pass def __mod__(self, x: int) -> int: pass def __neg__(self) -> int: pass def __pos__(self) -> int: pass + def __abs__(self) -> int: pass def __invert__(self) -> int: pass def __and__(self, n: int) -> int: pass def __or__(self, n: int) -> int: pass @@ -88,6 +93,9 @@ def __sub__(self, n: float) -> float: pass def __mul__(self, n: float) -> float: pass def __truediv__(self, n: float) -> float: pass def __neg__(self) -> float: pass + def __pos__(self) -> float: pass + def __abs__(self) -> float: pass + def __invert__(self) -> float: pass class complex: def __init__(self, x: object, y: object = None) -> None: pass @@ -296,7 +304,7 @@ def zip(x: Iterable[T], y: Iterable[S]) -> Iterator[Tuple[T, S]]: ... @overload def zip(x: Iterable[T], y: Iterable[S], z: Iterable[V]) -> Iterator[Tuple[T, S, V]]: ... def eval(e: str) -> Any: ... -def abs(x: float) -> float: ... +def abs(x: __SupportsAbs[T]) -> T: ... def exit() -> None: ... def min(x: T, y: T) -> T: ... def max(x: T, y: T) -> T: ... diff --git a/mypyc/test-data/irbuild-any.test b/mypyc/test-data/irbuild-any.test index bace026bc957..bcf9a1880635 100644 --- a/mypyc/test-data/irbuild-any.test +++ b/mypyc/test-data/irbuild-any.test @@ -176,3 +176,25 @@ L6: r4 = unbox(int, r3) n = r4 return 1 + +[case testAbsSpecialization] +# Specialization of native classes that implement __abs__ is checked in +# irbuild-dunders.test +def f() -> None: + a = abs(1) + b = abs(1.1) +[out] +def f(): + r0, r1 :: object + r2, a :: int + r3, r4, b :: float +L0: + r0 = object 1 + r1 = PyNumber_Absolute(r0) + r2 = unbox(int, r1) + a = r2 + r3 = 1.1 + r4 = PyNumber_Absolute(r3) + b = r4 + return 1 + diff --git a/mypyc/test-data/irbuild-dunders.test b/mypyc/test-data/irbuild-dunders.test index d06a570aa7b0..24e708913354 100644 --- a/mypyc/test-data/irbuild-dunders.test +++ b/mypyc/test-data/irbuild-dunders.test @@ -148,11 +148,19 @@ class C: def __float__(self) -> float: return 4.0 + def __pos__(self) -> int: + return 5 + + def __abs__(self) -> int: + return 6 + def f(c: C) -> None: -c ~c int(c) float(c) + +c + abs(c) [out] def C.__neg__(self): self :: __main__.C @@ -172,10 +180,19 @@ def C.__float__(self): L0: r0 = 4.0 return r0 +def C.__pos__(self): + self :: __main__.C +L0: + return 10 +def C.__abs__(self): + self :: __main__.C +L0: + return 12 def f(c): c :: __main__.C r0, r1 :: int r2, r3, r4, r5 :: object + r6, r7 :: int L0: r0 = c.__neg__() r1 = c.__invert__() @@ -183,5 +200,7 @@ L0: r3 = PyObject_CallFunctionObjArgs(r2, c, 0) r4 = load_address PyFloat_Type r5 = PyObject_CallFunctionObjArgs(r4, c, 0) + r6 = c.__pos__() + r7 = c.__abs__() return 1 diff --git a/mypyc/test-data/run-dunders.test b/mypyc/test-data/run-dunders.test index aee2a956c47f..0b156e5c3af8 100644 --- a/mypyc/test-data/run-dunders.test +++ b/mypyc/test-data/run-dunders.test @@ -332,6 +332,13 @@ class C: def __float__(self) -> float: return float(self.x + 4) + def __pos__(self) -> int: + return self.x + 5 + + def __abs__(self) -> int: + return abs(self.x) + 6 + + def test_unary_dunders_generic() -> None: a: Any = C(10) @@ -339,6 +346,8 @@ def test_unary_dunders_generic() -> None: assert ~a == 12 assert int(a) == 13 assert float(a) == 14.0 + assert +a == 15 + assert abs(a) == 16 def test_unary_dunders_native() -> None: c = C(10) @@ -347,6 +356,8 @@ def test_unary_dunders_native() -> None: assert ~c == 12 assert int(c) == 13 assert float(c) == 14.0 + assert +c == 15 + assert abs(c) == 16 [case testDundersBinarySimple] from typing import Any