Skip to content

Commit

Permalink
support __floordiv__ (#47060)
Browse files Browse the repository at this point in the history
  • Loading branch information
veyron95 committed Oct 17, 2022
1 parent 9e08633 commit 6430790
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 2 deletions.
94 changes: 94 additions & 0 deletions paddle/fluid/pybind/eager_math_op_patch.cc
Expand Up @@ -1295,6 +1295,96 @@ static PyObject* tensor__le__method(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Implements Tensor.__floordiv__ for eager (dygraph) mode: dispatches the
// Python expression `self // other` to floor_divide_ad_func.
// args[0] is the right-hand operand — a Python int/float, a numpy scalar,
// a complex number, or another Tensor. kwargs is unused (required by the
// PyCFunction METH_VARARGS | METH_KEYWORDS signature).
static PyObject* tensor__floordiv__method(TensorObject* self,
                                          PyObject* args,
                                          PyObject* kwargs) {
  // Profiling marker so this pybind patch function shows up in the tracer.
  paddle::platform::RecordEvent pythonc_record_event(
      "floordiv pybind_patch_func",
      paddle::platform::TracerEventType::UserDefined,
      1);
  EAGER_TRY
  VLOG(6) << "Running Eager tensor__floordiv__method";

  // Set Device ID
  auto place = egr::Controller::Instance().GetExpectedPlace();
  SetDevice(place);

  paddle::experimental::Tensor ret;
  paddle::experimental::Tensor self_tensor = self->tensor;

  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);

  // 1. Detect whether the right operand is a Python/numpy scalar.
  // Unlike some other ops there is no dedicated scalar kernel for floordiv,
  // but self_tensor may still need a dtype cast: when self is an integer
  // tensor and the divisor is a Python float, self is cast to FLOAT32 below
  // so the floor-divide runs in floating point.
  double other_double = 0.0;
  bool has_other_double = false;
  if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
      IsNumpyType(other_obj)) {
    if (PyFloat_Check(other_obj)) {
      other_double = CastPyArg2Double(other_obj, "__floordiv__", 0);
      has_other_double = true;
      // int tensor // float scalar -> promote self_tensor to FLOAT32 first.
      if (_supported_int_dtype_.find(self_tensor.dtype()) !=
          _supported_int_dtype_.end()) {
        eager_gil_scoped_release guard;
        self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
      }
    } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
      // Integer / numpy scalar: no cast of self_tensor is needed.
      other_double = CastPyArg2Double(other_obj, "__floordiv__", 0);
      has_other_double = true;
    }
  }

  // 2. Create or fetch a Tensor for other_obj so the op runs elementwise.
  paddle::experimental::Tensor other_tensor;
  if (has_other_double) {
    // Scalar path: materialize the scalar as a full tensor matching self's
    // shape, dtype and placement.
    eager_gil_scoped_release guard;
    other_tensor = full_ad_func(self_tensor.shape(),
                                phi::Scalar(other_double),
                                self_tensor.dtype(),
                                self_tensor.place());
  } else if (!PyCheckTensor(other_obj)) {
    // Neither a recognized numeric scalar nor a Tensor: convert through
    // Scalar (this is the path complex numbers take).
    paddle::experimental::Scalar value =
        CastPyArg2Scalar(other_obj, "__floordiv__", 0);
    if (PyComplex_Check(other_obj)) {
      eager_gil_scoped_release guard;
      other_tensor =
          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
    } else {
      eager_gil_scoped_release guard;
      other_tensor =
          full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place());
    }
  } else {
    // Already a Tensor: use it directly.
    other_tensor = CastPyArg2Tensor(other_obj, 0);
  }

  // 3. Promote types or unify the right operand's dtype to the left's.
  phi::DataType lhs_dtype = self_tensor.dtype();
  phi::DataType rhs_dtype = other_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
    // note: only op_type in _supported_promote_complex_types_ should promote
    // dtype; floordiv is not in that set, so no complex promotion happens —
    // the right tensor is simply cast to the left tensor's dtype.
    VLOG(6) << "The dtype of left and right Tensor are not the same, left "
               "dtype is "
            << lhs_dtype << ", but right dtype is " << rhs_dtype
            << ", the right dtype will convert to " << lhs_dtype;
    eager_gil_scoped_release guard;
    other_tensor = cast_ad_func(other_tensor, lhs_dtype);
  }

  // 4. Run the actual kernel, releasing the GIL for the duration of the op.
  VLOG(6) << "Calling floor_divide_ad_func in tensor__floordiv__method";
  {
    eager_gil_scoped_release guard;
    ret = floor_divide_ad_func(self_tensor, other_tensor);
  }

  return ToPyObject(ret);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyMethodDef math_op_patch_methods[] = {
{"__add__",
(PyCFunction)(void (*)(void))tensor__add__method,
Expand Down Expand Up @@ -1336,6 +1426,10 @@ PyMethodDef math_op_patch_methods[] = {
(PyCFunction)(void (*)(void))tensor__rdiv__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__floordiv__",
(PyCFunction)(void (*)(void))tensor__floordiv__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__mod__",
(PyCFunction)(void (*)(void))tensor__mod__method,
METH_VARARGS | METH_KEYWORDS,
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/fluid/dygraph/math_op_patch.py
Expand Up @@ -392,8 +392,6 @@ def __impl__(self, other_var):
True)),
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
None)),
('__floordiv__',
_binary_creator_('__floordiv__', 'floor_divide', False, None, True)),
# for logical compare
('__eq__', _binary_creator_('__eq__', 'equal', False, None, True)),
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None, True)),
Expand All @@ -417,6 +415,7 @@ def __impl__(self, other_var):
'__ge__',
'__lt__',
'__le__',
'__floordiv__',
]

global _already_patch_varbase
Expand Down

0 comments on commit 6430790

Please sign in to comment.