Skip to content

Commit

Permalink
bind elementwise_mod_op_xpu (#42175)
Browse files Browse the repository at this point in the history
* bind elementwise_mod_op_xpu *test=kunlun

* add more supported dtypes and UTs *test=kunlun

* fix datatype error

* add op to in xpu1_op_list

* Update Mac cmake version >=3.15 (#41456)

* Update Mac cmake version >=3.15

* notest;read test1

notest;read test2

notest;read test3

* fix inference link error

* fix inference link error

* fix windows link error

* fix cmake_policy

* fix build big size

* Add paddle::variant and replace paddle::any (#42139)

* add variant and replace any

* split attribute

* disable unittest failed in eager CI in temporary (#42101)

* test=py3-eager

* test=py3-eager

* test=py3-eager

* combine graph_table and feature_table in graph_engine (#42134)

* extract sub-graph

* graph-engine merging

* fix

* fix

* fix heter-ps config

* test performance

* test performance

* test performance

* test

* test

* update bfs

* change cmake

* test

* test gpu speed

* gpu_graph_engine optimization

* add dsm sample method

* add graph_neighbor_sample_v2

* Add graph_neighbor_sample_v2

* fix for loop

* add cpu sample interface

* fix kernel judgement

* add ssd layer to graph_engine

* fix allocation

* fix syntax error

* fix syntax error

* fix pscore class

* fix

* change index settings

* recover test

* recover test

* fix spelling

* recover

* fix

* move cudamemcpy after cuda stream sync

* fix linking problem

* remove comment

* add cpu test

* test

* add cpu test

* change comment

* combine feature table and graph table

* test

* test

* pybind

* test

* test

* test

* test

* pybind

* pybind

* fix cmake

* pybind

* fix

* fix

* add pybind

* add pybind

Co-authored-by: DesmonDay <908660116@qq.com>

* [CustomDevice] add eager mode support (#42034)

* fix FlattenContiguousRangeOpConverter out dim error (#42087)

* fix FlattenContiguousRangeOpConverter out dim error

* update code

* fix python3.10 compile bug on windows (#42140)

* Optimize dygraph GetExpectedKernelType perf (#42154)

* opt dygraph scheduling

* revert part impl

* fix incorrect usages of std::move and other compile errors (#41045)

* fix bug of std::move and others

* fix an compile error in debug mode

* fix wrong copy assignment operator

Signed-off-by: tiancaishaonvjituizi <452565578@qq.com>

* reformat

Signed-off-by: tiancaishaonvjituizi <452565578@qq.com>

* reformat

Signed-off-by: tiancaishaonvjituizi <452565578@qq.com>

* fix ArrayRef constructor following llvm

* fix format

* fix conflict with master

* fix variant compile error (#42203)

* [Eager] Support numpy.ndarry in CastNumpy2Scalar (#42136)

* [Eager] Remove redundancy code, fix fp16 case (#42169)

* [Eager] Support div(scalar) in eager mode (#42148)

* [Eager] Support div scalar in eager mode

* Updated and remove debug logs

* Remove list, use 'or' directly

* Remove useless statement

* fix recompute (#42128)

* fix recompute

* modify return

* add LICENSE in wheel dist-info package (#42187)

* replace any by variant in infermeta (#42181)

* 【PaddlePaddle Hackathon 2】24、为 Paddle 新增 nn.ChannelShuffle 组网 API (#40743)

* Add infermeta for ChannelShuffle

* Create channel_shuffle_grad_kernel.h

* Create channel_shuffle_kernel.h

* Create channel_shuffle_sig.cc

* Create channel_shuffle_op.cc

ChannelShuffle算子的描述

* Create channel_shuffle_kernel_impl.h

ChannelShuffle核函数的实现

* Create channel_shuffle_grad_kernel_impl.h

ChannelShuffle反向核函数的实现

* Add kernel register of channel shuffle and grad

注册ChannelShuffle及其反向的核函数

* add nn.functional.channel_shuffle

* add nn.ChannelShuffle

* Create test_channel_shuffle.py

* Update example of ChannelShuffle in vision.py

* Update test_channel_shuffle.py

* 修改channel_shuffle核函数的实现位置

* 修正代码格式

* 删除多余空格

* 完善channel_shuffle的错误检查

* Update unary.cc

* Update channel_shuffle_op.cc

* Update test_channel_shuffle.py

* Update unary.cc

* add channel_shuffle

* Update test_channel_shuffle.py

* Update vision.py

* 调整代码格式

* Update channel_shuffle_sig.cc

* 更新ChannelShuffle的文档

* 更新channel_shuffle的文档

* remove ChannelShuffleOpArgumentMapping

* add ChannelShuffleGradInferMeta

* Update channel_shuffle_op.cc

* 调整channel_shuffle及其梯度的核函数的位置

* Do not reset default stream for StreamSafeCUDAAllocator (#42149)

* remove redundant computation in Categorical.probs (#42114)

* Downloading data for test_analyzer_vit_ocr (#42041)

* Change server URL

* update config

* add test to parallel UT rule

* add checksum to ensure files are downloaded

* change downloading target

* reuse existing variable

* change target directory

* fix en docs of some Apis (gradients, scope_guard, cuda_places, name_scope, device_guard, load_program_state, scale, ParamAttr and WeightNormParamAttr) (#41604)

* Update scope_guard; test=document_fix

* gradients; test=document_fix

* gradients; test=document_fix

* name_scope; test=document_fix

* cpu_places; test=document_fix

* WeightNormParamAttr; test=document_fix

* cuda_places; test=document_fix

* load_program_state; test=document_fix

* device_guard; test=document_fix

* device_guard; test=document_fix

* ParamAttr; test=document_fix

* scale; test=document_fix

* scale; test=document_fix

* update code example;test=document_fix

Co-authored-by: Chen Long <1300851984@qq.com>

* fix datatype error

add op to in xpu1_op_list

*test=kunlun

* fix elementwise_mod op path error  *test=kunlun

* fix elementwise_mod UT error  *test=kunlun

* fix datatype error

add op to in xpu1_op_list

*test=kunlun

add op to in xpu1_op_list

fix elementwise_mod op path error  *test=kunlun

fix elementwise_mod UT error  *test=kunlun

Co-authored-by: tianshuo78520a <707759223@qq.com>
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: seemingwang <seemingwang@users.noreply.github.com>
Co-authored-by: DesmonDay <908660116@qq.com>
Co-authored-by: ronnywang <524019753@qq.com>
Co-authored-by: baoachun <962571062@qq.com>
Co-authored-by: Zhou Wei <1183042833@qq.com>
Co-authored-by: tiancaishaonvjituizi <452565578@qq.com>
Co-authored-by: Weilong Wu <veyron_wu@163.com>
Co-authored-by: Roc <30228238+sljlp@users.noreply.github.com>
Co-authored-by: BrilliantYuKaimin <91609464+BrilliantYuKaimin@users.noreply.github.com>
Co-authored-by: Ruibiao Chen <chenruibiao@baidu.com>
Co-authored-by: Feiyu Chan <chenfeiyu@baidu.com>
Co-authored-by: Sławomir Siwek <slawomir.siwek@intel.com>
Co-authored-by: Yilingyelu <103369238+Yilingyelu@users.noreply.github.com>
Co-authored-by: Chen Long <1300851984@qq.com>
  • Loading branch information
18 people committed May 6, 2022
1 parent e3e3fa6 commit 6ea2f04
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 0 deletions.
49 changes: 49 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_mod_op_xpu.cc
@@ -0,0 +1,49 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"

#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
#include "paddle/fluid/platform/device/device_wrapper.h"

namespace paddle {
namespace operators {

template <typename T>
class ElementwiseModXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;

public:
void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T, XPUType>(ctx, xpu::broadcast_mod<XPUType>);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_XPU_KERNEL(elementwise_mod, ops::ElementwiseModXPUKernel<float>,
ops::ElementwiseModXPUKernel<paddle::platform::float16>,
ops::ElementwiseModXPUKernel<int32_t>,
ops::ElementwiseModXPUKernel<int64_t>);

#endif
5 changes: 5 additions & 0 deletions paddle/fluid/platform/device/xpu/xpu2_op_list.h
Expand Up @@ -130,6 +130,11 @@ XPUOpMap& get_kl2_ops() {
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_mod",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
Expand Down
133 changes: 133 additions & 0 deletions python/paddle/fluid/tests/unittests/xpu/test_elementwise_mod_op_xpu.py
@@ -0,0 +1,133 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard

import paddle
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper

paddle.enable_static()


class XPUTestElementwiseModOp(XPUOpTestWrapper):
def __init__(self) -> None:
self.op_name = 'elementwise_mod'
self.use_dynamic_create_class = False

class ElementwiseModOp(XPUOpTest):
def init_kernel_type(self):
self.use_mkldnn = False

def init_input_output(self):
self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype)
self.out = np.mod(self.x, self.y)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.outputs = {'Out': self.out}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}

def init_dtype(self):
pass

def init_axis(self):
pass

def setUp(self):
self.op_type = 'elementwise_mod'
self.use_xpu = True
self.dtype = self.in_type
self.axis = -1
self.init_dtype()
self.init_input_output()
self.init_kernel_type()
self.init_axis()

def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)

class TestElementwiseModOp_broadcast_1(ElementwiseModOp):
def init_input_output(self):
self.inputs = {
'X': np.random.rand(2, 100, 3).astype(self.dtype),
'Y': np.random.rand(2, 100, 3).astype(self.dtype)
}

self.attrs = {'axis': 1}
self.outputs = {'Out': self.inputs['X'] % self.inputs['Y']}

class TestElementwiseModOp_broadcast_2(ElementwiseModOp):
def init_input_output(self):
self.inputs = {
'X': np.random.rand(22, 128, 3).astype(self.dtype),
'Y': np.random.rand(22, 128, 3).astype(self.dtype)
}

self.attrs = {'axis': 1}
self.outputs = {'Out': self.inputs['X'] % self.inputs['Y']}

class TestRemainderOp(unittest.TestCase):
def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.random.rand(22, 128, 3).astype('int64')
np_y = np.random.rand(22, 128, 3).astype('int64')
x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)
z = paddle.remainder(x, y)
np_z = z.numpy()
z_expected = np.mod(np_x, np_y)
self.assertEqual((np_z == z_expected).all(), True)

np_x = np.array([-3.3, 11.5, -2, 3.5])
np_y = np.array([-1.2, 2., 3.3, -2.3])
x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)
z = x % y
z_expected = np.array([-0.9, 1.5, 1.3, -1.1])
self.assertEqual(np.allclose(z_expected, z.numpy()), True)

np_x = np.random.rand(22, 128, 3).astype('int32')
np_y = np.random.rand(22, 128, 3).astype('int32')
x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)
z = paddle.remainder(x, y)
np_z = z.numpy()
z_expected = np.mod(np_x, np_y)
self.assertEqual((np_z == z_expected).all(), True)

np_x = np.array([-3, 11, -2, 3])
np_y = np.array([-1, 2, 3, -2])
x = paddle.to_tensor(np_x, dtype="float16")
y = paddle.to_tensor(np_y, dtype="float16")
z = x % y
z_expected = np.array([0, 1, 1, -1])
self.assertEqual(np.allclose(z_expected, z.numpy()), True)


support_types = get_xpu_op_support_types('elementwise_mod')
for stype in support_types:
create_test_class(globals(), XPUTestElementwiseModOp, stype)

if __name__ == '__main__':
unittest.main()

0 comments on commit 6ea2f04

Please sign in to comment.