Skip to content

Commit

Permalink
change reindex
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Aug 25, 2022
1 parent e9c50fb commit c8b50fb
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 27 deletions.
18 changes: 9 additions & 9 deletions python/paddle/fluid/tests/unittests/test_graph_reindex.py
Expand Up @@ -213,13 +213,13 @@ def test_reindex_result(self):
index_buffer = paddle.full([self.num_nodes], -1, dtype="int32")

reindex_src, reindex_dst, out_nodes = \
paddle.geometric.graph_reindex(x, neighbors, count)
paddle.geometric.reindex_graph(x, neighbors, count)
np.testing.assert_allclose(self.reindex_src, reindex_src, rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst, reindex_dst, rtol=1e-05)
np.testing.assert_allclose(self.out_nodes, out_nodes, rtol=1e-05)

reindex_src, reindex_dst, out_nodes = \
paddle.geometric.graph_reindex(x, neighbors, count,
paddle.geometric.reindex_graph(x, neighbors, count,
value_buffer, index_buffer,
has_buffer_hashtable=True)
np.testing.assert_allclose(self.reindex_src, reindex_src, rtol=1e-05)
Expand All @@ -235,7 +235,7 @@ def test_heter_reindex_result(self):
count = paddle.concat([count, count])

reindex_src, reindex_dst, out_nodes = \
paddle.geometric.graph_reindex(x, neighbors, count)
paddle.geometric.reindex_graph(x, neighbors, count)
np.testing.assert_allclose(self.reindex_src,
reindex_src[:self.neighbors.shape[0]],
rtol=1e-05)
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_heter_reindex_result_v2(self):
reindex_dst = np.array(reindex_dst, dtype="int64")

reindex_src_, reindex_dst_, out_nodes_ = \
paddle.geometric.graph_reindex(paddle.to_tensor(x),
paddle.geometric.reindex_graph(paddle.to_tensor(x),
paddle.to_tensor(neighbors),
paddle.to_tensor(counts))
np.testing.assert_allclose(reindex_src, reindex_src_, rtol=1e-05)
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_heter_reindex_result_v3(self):
neighbors = [paddle.to_tensor(neighbors1), paddle.to_tensor(neighbors2)]
count = [paddle.to_tensor(count1), paddle.to_tensor(count2)]
reindex_src_, reindex_dst_, out_nodes_ = \
paddle.geometric.heter_graph_reindex(paddle.to_tensor(x),
paddle.geometric.reindex_heter_graph(paddle.to_tensor(x),
neighbors, count)
np.testing.assert_allclose(reindex_src, reindex_src_, rtol=1e-05)
np.testing.assert_allclose(reindex_dst, reindex_dst_, rtol=1e-05)
Expand All @@ -337,9 +337,9 @@ def test_reindex_result_static(self):
dtype="int32")

reindex_src_1, reindex_dst_1, out_nodes_1 = \
paddle.geometric.graph_reindex(x, neighbors, count)
paddle.geometric.reindex_graph(x, neighbors, count)
reindex_src_2, reindex_dst_2, out_nodes_2 = \
paddle.geometric.graph_reindex(x, neighbors, count,
paddle.geometric.reindex_graph(x, neighbors, count,
value_buffer, index_buffer,
has_buffer_hashtable=True)

Expand Down Expand Up @@ -420,11 +420,11 @@ def test_heter_reindex_result_static(self):
dtype="int32")

reindex_src_1, reindex_dst_1, out_nodes_1 = \
paddle.geometric.heter_graph_reindex(x,
paddle.geometric.reindex_heter_graph(x,
[neighbors1, neighbors2],
[count1, count2])
reindex_src_2, reindex_dst_2, out_nodes_2 = \
paddle.geometric.heter_graph_reindex(x,
paddle.geometric.reindex_heter_graph(x,
[neighbors1, neighbors2],
[count1, count2],
value_buffer, index_buffer,
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/geometric/__init__.py
Expand Up @@ -19,8 +19,8 @@
from .math import segment_mean # noqa: F401
from .math import segment_min # noqa: F401
from .math import segment_max # noqa: F401
from .sampling import graph_reindex # noqa: F401
from .sampling import heter_graph_reindex # noqa: F401
from .reindex import reindex_graph # noqa: F401
from .reindex import reindex_heter_graph # noqa: F401
from .sampling import sample_neighbors # noqa: F401

__all__ = [
Expand All @@ -31,7 +31,7 @@
'segment_mean',
'segment_min',
'segment_max',
'graph_reindex',
'heter_graph_reindex',
'reindex_graph',
'reindex_heter_graph',
'sample_neighbors',
]
Expand Up @@ -22,15 +22,15 @@
__all__ = []


def graph_reindex(x,
def reindex_graph(x,
neighbors,
count,
value_buffer=None,
index_buffer=None,
has_buffer_hashtable=False,
name=None):
"""
Graph Reindex API.
Reindex Graph API.
This API is mainly used in Graph Learning domain, which should be used
in conjunction with `graph_sample_neighbors` API. And the main purpose
Expand All @@ -41,7 +41,7 @@ def graph_reindex(x,
The number in x should be unique, otherwise it would cause potential errors.
We will reindex all the nodes from 0.
Take input nodes x = [0, 1, 2] as an example.
Take input nodes x = [0, 1, 2] as an example.
If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2],
then we know that the neighbors of 0 is [8, 9], the neighbors of 1
is [0, 4, 7], and the neighbors of 2 is [6, 7].
Expand Down Expand Up @@ -89,7 +89,7 @@ def graph_reindex(x,
count = paddle.to_tensor(count, dtype="int32")
reindex_src, reindex_dst, out_nodes = \
paddle.geometric.graph_reindex(x, neighbors, count)
paddle.geometric.reindex_graph(x, neighbors, count)
# reindex_src: [3, 4, 0, 5, 6, 7, 6]
# reindex_dst: [0, 0, 1, 1, 1, 2, 2]
# out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]
Expand Down Expand Up @@ -117,7 +117,7 @@ def graph_reindex(x,
check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"),
"graph_reindex")

helper = LayerHelper("graph_reindex", **locals())
helper = LayerHelper("reindex_graph", **locals())
reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype)
reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype)
out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype)
Expand All @@ -143,15 +143,15 @@ def graph_reindex(x,
return reindex_src, reindex_dst, out_nodes


def heter_graph_reindex(x,
def reindex_heter_graph(x,
neighbors,
count,
value_buffer=None,
index_buffer=None,
has_buffer_hashtable=False,
name=None):
"""
HeterGraph Reindex API.
Reindex HeterGraph API.
This API is mainly used in Graph Learning domain, which should be used
in conjunction with `graph_sample_neighbors` API. And the main purpose
Expand All @@ -160,7 +160,7 @@ def heter_graph_reindex(x,
**Notes**:
The number in x should be unique, otherwise it would cause potential errors.
We support multi-edge-types neighbors reindexing in heter_graph_reindex api.
We support multi-edge-types neighbors reindexing in reindex_heter_graph api.
We will reindex all the nodes from 0.
Take input nodes x = [0, 1, 2] as an example.
Expand Down Expand Up @@ -219,7 +219,7 @@ def heter_graph_reindex(x,
neighbors = [neighbors_a, neighbors_b]
count = [count_a, count_b]
reindex_src, reindex_dst, out_nodes = \
paddle.geometric.heter_graph_reindex(x, neighbors, count)
paddle.geometric.reindex_heter_graph(x, neighbors, count)
# reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1]
# reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2]
# out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5]
Expand Down Expand Up @@ -258,7 +258,7 @@ def heter_graph_reindex(x,
check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"),
"graph_reindex")

helper = LayerHelper("heter_graph_reindex", **locals())
helper = LayerHelper("reindex_heter_graph", **locals())
reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype)
reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype)
out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype)
Expand Down
4 changes: 0 additions & 4 deletions python/paddle/geometric/sampling/__init__.py
Expand Up @@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .graph_reindex import graph_reindex # noqa: F401
from .graph_reindex import heter_graph_reindex # noqa: F401
from .neighbors import sample_neighbors # noqa: F401

__all__ = [
'graph_reindex',
'heter_graph_reindex',
'sample_neighbors',
]

0 comments on commit c8b50fb

Please sign in to comment.