add paddle.geometric.send_uv api, add unittest
DesmonDay committed Aug 8, 2022
1 parent 7ab77f9 commit 6fdfcb1
Showing 3 changed files with 236 additions and 2 deletions.
1 change: 1 addition & 0 deletions python/paddle/__init__.py
@@ -78,6 +78,7 @@
import paddle.reader # noqa: F401
import paddle.static # noqa: F401
import paddle.vision # noqa: F401
import paddle.geometric # noqa: F401

from .tensor.attribute import is_complex # noqa: F401
from .tensor.attribute import is_integer # noqa: F401
116 changes: 114 additions & 2 deletions python/paddle/fluid/tests/unittests/test_graph_send_uv_op.py
@@ -41,10 +41,17 @@ def compute_graph_send_uv(inputs, attributes):
return results


def graph_send_uv_wrapper(x, y, src_index, dst_index, compute_type="add"):
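# Wrapper exposing the new paddle.geometric.send_uv API to OpTest via `python_api`;
# assumed to be the entry point OpTest exercises when check_eager=True.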
return paddle.geometric.send_uv(x, y, src_index, dst_index,
compute_type.lower())


class TestGraphSendUVOp(OpTest):

def setUp(self):
paddle.enable_static()
self.python_api = graph_send_uv_wrapper
self.python_out_sig = ['out']
self.op_type = "graph_send_uv"
self.set_config()
self.inputs = {
@@ -58,10 +65,10 @@ def setUp(self):
self.outputs = {'out': out}

def test_check_output(self):
self.check_output(check_eager=False)
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['x', 'y'], 'out', check_eager=False)
self.check_grad(['x', 'y'], 'out', check_eager=True)

def set_config(self):
self.x = np.random.random((10, 20)).astype("float64")
@@ -147,3 +154,108 @@ def set_config(self):
self.src_index = index[:, 0]
self.dst_index = index[:, 1]
self.compute_type = 'MUL'


class API_GeometricSendUVTest(unittest.TestCase):

def test_compute_all_dygraph(self):
paddle.disable_static()
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
y = paddle.to_tensor([[1, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32")
dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32")
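# Per the send_uv docstring, each edge i combines x[src_index[i]] with y[dst_index[i]] elementwise.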

res_add = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
compute_type="add")
res_sub = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
compute_type="sub")
res_mul = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
compute_type="mul")
res_div = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
compute_type="div")
res = [res_add, res_sub, res_mul, res_div]

np_add = np.array([[2, 5, 7], [5, 9, 11], [4, 9, 11], [1, 3, 5]],
dtype="float32")
np_sub = np.array([[-2, -1, -1], [-3, -1, -1], [0, 3, 3], [-1, 1, 1]],
dtype="float32")
np_mul = np.array([[0, 6, 12], [4, 20, 30], [4, 18, 28], [0, 2, 6]],
dtype="float32")
np_div = np.array(
[[0, 2 / 3, 0.75], [0.25, 0.8, 5 / 6], [1, 2, 7 / 4], [0, 2, 1.5]],
dtype="float32")

for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], res):
self.assertTrue(
np.allclose(np_res, paddle_res, atol=1e-6), "two values are\
{}\n{}, check diff!".format(np_res, paddle_res))

def test_compute_all_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
y = paddle.static.data(name="y", shape=[3, 3], dtype="float32")
src_index = paddle.static.data(name="src", shape=[4], dtype="int32")
dst_index = paddle.static.data(name="dst", shape=[4], dtype="int32")
res_add = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
compute_type="add")
res_sub = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
compute_type="sub")
res_mul = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
compute_type="mul")
res_div = paddle.geometric.send_uv(x,
y,
src_index,
dst_index,
compute_type="div")

exe = paddle.static.Executor(paddle.CPUPlace())
data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
data2 = np.array([[1, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
data3 = np.array([0, 1, 2, 0], dtype="int32")
data4 = np.array([1, 2, 1, 0], dtype="int32")

np_add = np.array([[2, 5, 7], [5, 9, 11], [4, 9, 11], [1, 3, 5]],
dtype="float32")
np_sub = np.array(
[[-2, -1, -1], [-3, -1, -1], [0, 3, 3], [-1, 1, 1]],
dtype="float32")
np_mul = np.array([[0, 6, 12], [4, 20, 30], [4, 18, 28], [0, 2, 6]],
dtype="float32")
np_div = np.array([[0, 2 / 3, 0.75], [0.25, 0.8, 5 / 6],
[1, 2, 7 / 4], [0, 2, 1.5]],
dtype="float32")

ret = exe.run(feed={
'x': data1,
'y': data2,
'src': data3,
'dst': data4,
},
fetch_list=[res_add, res_sub, res_mul, res_div])
for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div],
ret):
self.assertTrue(
np.allclose(np_res, paddle_res, atol=1e-6), "two values are\
{}\n{}, check diff!".format(np_res, paddle_res))
121 changes: 121 additions & 0 deletions python/paddle/geometric/message_passing/send.py
@@ -11,3 +11,124 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle import _C_ops

from .utils import reshape_lhs_rhs


def send_uv(x, y, src_index, dst_index, compute_type="add", name=None):
"""
Graph Learning message passing API.
This API is mainly used in the Graph Learning domain, and its main purpose is to reduce intermediate memory
consumption during message passing. Take `x` as the source node feature tensor and `y` as
the destination node feature tensor. `src_index` and `dst_index` are used to gather the corresponding rows,
and the edge features are then computed with one of the compute types `add`, `sub`, `mul` or `div`.
.. code-block:: text
Given:
X = [[0, 2, 3],
[1, 4, 5],
[2, 6, 7]]
Y = [[0, 1, 2],
[2, 3, 4],
[4, 5, 6]]
src_index = [0, 1, 2, 0]
dst_index = [1, 2, 1, 0]
compute_type = "add"
Then:
Out = [[2, 5, 7],
[5, 9, 11],
[4, 9, 11],
[0, 3, 5]]
Args:
x (Tensor): The source node feature tensor. The available data types are float32, float64, int32, int64; float16 is also supported on GPU.
y (Tensor): The destination node feature tensor. The available data types are float32, float64, int32, int64; float16 is also supported on GPU.
src_index (Tensor): A 1-D tensor. The available data types are int32, int64.
dst_index (Tensor): A 1-D tensor, which should have the same shape as `src_index`.
The available data types are int32, int64.
compute_type (str): The compute type applied to x and y, one of `add`, `sub`, `mul` or `div`. Default value is `add`.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
out (Tensor): The output tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
y = paddle.to_tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
src_index = indexes[:, 0]
dst_index = indexes[:, 1]
out = paddle.geometric.send_uv(x, y, src_index, dst_index, compute_type="add")
# Outputs: [[2., 5., 7.], [5., 9., 11.], [4., 9., 11.], [0., 3., 5.]]
"""

if compute_type not in ['add', 'sub', 'mul', 'div']:
raise ValueError(
"compute_type should be `add`, `sub`, `mul`, `div`, but received %s"
% compute_type)

x, y = reshape_lhs_rhs(x, y)

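# 'sub' and 'div' are lowered to 'add' and 'mul' on a transformed y (y -> -y and y -> 1/y),
# so the underlying kernel only needs to handle the 'add' and 'mul' compute types.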
if compute_type == 'sub':
compute_type = 'add'
y = -y
if compute_type == 'div':
compute_type = 'mul'
y = 1. / y
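# Dispatch: the new eager mode calls the final-state op, legacy dygraph calls the C++ op
# directly, and static graph mode appends a graph_send_uv op through LayerHelper.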

if in_dygraph_mode():
return _C_ops.final_state_graph_send_uv(x, y, src_index, dst_index,
compute_type.upper())
else:
if _in_legacy_dygraph():
return _C_ops.graph_send_uv(x, y, src_index, dst_index,
"compute_type", compute_type.upper())
else:
helper = LayerHelper("send_uv", **locals())
check_variable_and_dtype(x, 'x',
['int32', 'int64', 'float32', 'float64'],
'graph_send_uv')
check_variable_and_dtype(y, 'y',
['int32', 'int64', 'float32', 'float64'],
'graph_send_uv')
check_variable_and_dtype(src_index, 'src_index', ['int32', 'int64'],
'graph_send_uv')
check_variable_and_dtype(dst_index, 'dst_index', ['int32', 'int64'],
'graph_send_uv')
out = helper.create_variable_for_type_inference(dtype=x.dtype)

inputs = {
'x': x,
'y': y,
'src_index': src_index,
'dst_index': dst_index
}
attrs = {'compute_type': compute_type.upper()}
helper.append_op(type="graph_send_uv",
inputs=inputs,
attrs=attrs,
outputs={"out": out})
return out
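
For reference, the sketch below is a minimal NumPy version of the behavior described in the send_uv docstring (gather rows of x by src_index and rows of y by dst_index, then apply the elementwise compute type). It assumes 2-D inputs and in-range indices; send_uv_numpy is a hypothetical helper for illustration, not part of the API.

import numpy as np

def send_uv_numpy(x, y, src_index, dst_index, compute_type="add"):
    # Gather the per-edge source rows from x and destination rows from y.
    src_feat = x[src_index]
    dst_feat = y[dst_index]
    # Apply the elementwise edge computation.
    ops = {"add": np.add, "sub": np.subtract, "mul": np.multiply, "div": np.divide}
    return ops[compute_type](src_feat, dst_feat)

x = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
y = np.array([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32")
src_index = np.array([0, 1, 2, 0])
dst_index = np.array([1, 2, 1, 0])
print(send_uv_numpy(x, y, src_index, dst_index, "add"))
# [[ 2.  5.  7.]
#  [ 5.  9. 11.]
#  [ 4.  9. 11.]
#  [ 0.  3.  5.]]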
