Skip to content

Commit

Permalink
change reindex and sample neighbors flag name
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Aug 17, 2022
1 parent 1e06e51 commit d5ac560
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 24 deletions.
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/unittests/test_graph_reindex.py
Expand Up @@ -213,7 +213,7 @@ def test_reindex_result(self):
reindex_src, reindex_dst, out_nodes = \
paddle.geometric.graph_reindex(x, neighbors, count,
value_buffer, index_buffer,
flag_buffer_hashtable=True)
has_buffer_hashtable=True)
self.assertTrue(np.allclose(self.reindex_src, reindex_src))
self.assertTrue(np.allclose(self.reindex_dst, reindex_dst))
self.assertTrue(np.allclose(self.out_nodes, out_nodes))
Expand Down Expand Up @@ -299,7 +299,7 @@ def test_reindex_result_static(self):
reindex_src_2, reindex_dst_2, out_nodes_2 = \
paddle.geometric.graph_reindex(x, neighbors, count,
value_buffer, index_buffer,
flag_buffer_hashtable=True)
has_buffer_hashtable=True)

exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(feed={
Expand Down
Expand Up @@ -289,7 +289,7 @@ def test_sample_result_fisher_yates_sampling(self):
nodes,
perm_buffer=perm_buffer,
sample_size=self.sample_size,
flag_perm_buffer=True)
has_perm_buffer=True)
out_count_cumsum = paddle.cumsum(out_count)
for i in range(len(out_count)):
if i == 0:
Expand Down Expand Up @@ -361,7 +361,7 @@ def check_perm_buffer_error():
colptr,
nodes,
sample_size=self.sample_size,
flag_perm_buffer=True)
has_perm_buffer=True)

self.assertRaises(ValueError, check_eid_error)
self.assertRaises(ValueError, check_perm_buffer_error)
Expand Down Expand Up @@ -390,7 +390,7 @@ def test_sample_result_with_eids(self):
perm_buffer=perm_buffer,
sample_size=self.sample_size,
return_eids=True,
flag_perm_buffer=True)
has_perm_buffer=True)

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/geometric/sampling/__init__.py
Expand Up @@ -15,3 +15,9 @@
from .graph_reindex import graph_reindex # noqa: F401
from .khop_sampler import khop_sampler # noqa: F401
from .sample_neighbors import sample_neighbors # noqa: F401

__all__ = [
'graph_reindex',
'khop_sampler',
'sample_neighbors',
]
20 changes: 11 additions & 9 deletions python/paddle/geometric/sampling/graph_reindex.py
Expand Up @@ -19,13 +19,15 @@
from paddle.fluid import core
from paddle import _C_ops

__all__ = []


def graph_reindex(x,
neighbors,
count,
value_buffer=None,
index_buffer=None,
flag_buffer_hashtable=False,
has_buffer_hashtable=False,
name=None):
"""
Graph Reindex API.
Expand Down Expand Up @@ -57,7 +59,7 @@ def graph_reindex(x,
be int32, and should be filled with -1.
index_buffer (Tensor|None): Index buffer for hashtable. The data type should
be int32, and should be filled with -1.
flag_buffer_hashtable (bool): Whether to use buffer for hashtable to speed up.
has_buffer_hashtable (bool): Whether to use buffer for hashtable to speed up.
Default is False. Only useful for gpu version currently.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -102,23 +104,23 @@ def graph_reindex(x,
# out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5]
"""
if flag_buffer_hashtable:
if has_buffer_hashtable:
if value_buffer is None or index_buffer is None:
raise ValueError(f"`value_buffer` and `index_buffer` should not"
"be None if `flag_buffer_hashtable` is True.")
"be None if `has_buffer_hashtable` is True.")

if _non_static_mode():
reindex_src, reindex_dst, out_nodes = \
_C_ops.graph_reindex(x, neighbors, count, value_buffer, index_buffer,
"flag_buffer_hashtable", flag_buffer_hashtable)
"flag_buffer_hashtable", has_buffer_hashtable)
return reindex_src, reindex_dst, out_nodes

check_variable_and_dtype(x, "X", ("int32", "int64"), "graph_reindex")
check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"),
"graph_reindex")
check_variable_and_dtype(count, "Count", ("int32"), "graph_reindex")

if flag_buffer_hashtable:
if has_buffer_hashtable:
check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32"),
"graph_reindex")
check_variable_and_dtype(index_buffer, "HashTable_Value", ("int32"),
Expand All @@ -137,14 +139,14 @@ def graph_reindex(x,
"Count":
count,
"HashTable_Value":
value_buffer if flag_buffer_hashtable else None,
value_buffer if has_buffer_hashtable else None,
"HashTable_Index":
index_buffer if flag_buffer_hashtable else None,
index_buffer if has_buffer_hashtable else None,
},
outputs={
"Reindex_Src": reindex_src,
"Reindex_Dst": reindex_dst,
"Out_Nodes": out_nodes
},
attrs={"flag_buffer_hashtable": flag_buffer_hashtable})
attrs={"flag_buffer_hashtable": has_buffer_hashtable})
return reindex_src, reindex_dst, out_nodes
2 changes: 2 additions & 0 deletions python/paddle/geometric/sampling/khop_sampler.py
Expand Up @@ -19,6 +19,8 @@
from paddle.fluid import core
from paddle import _C_ops

__all__ = []


def khop_sampler(row,
colptr,
Expand Down
21 changes: 11 additions & 10 deletions python/paddle/geometric/sampling/sample_neighbors.py
Expand Up @@ -19,6 +19,8 @@
from paddle.fluid import core
from paddle import _C_ops

__all__ = []


def sample_neighbors(row,
colptr,
Expand All @@ -27,7 +29,7 @@ def sample_neighbors(row,
perm_buffer=None,
sample_size=-1,
return_eids=False,
flag_perm_buffer=False,
has_perm_buffer=False,
name=None):
"""
Graph Sample Neighbors API.
Expand All @@ -53,13 +55,13 @@ def sample_neighbors(row,
eids (Tensor): The eid information of the input graph. If return_eids is True,
then `eids` should not be None. The data type should be the
same with `row`. Default is None.
perm_buffer (Tensor): Permutation buffer for fisher-yates sampling. If `flag_perm_buffer`
perm_buffer (Tensor): Permutation buffer for fisher-yates sampling. If `has_perm_buffer`
is True, then `perm_buffer` should not be None. The data type should
be the same with `row`. Default is None.
sample_size (int): The number of neighbors we need to sample. Default value is
-1, which means returning all the neighbors of the input nodes.
return_eids (bool): Whether to return eid information of sample edges. Default is False.
flag_perm_buffer (bool): Using the permutation for fisher-yates sampling in GPU. Default
has_perm_buffer (bool): Using the permutation for fisher-yates sampling in GPU. Default
value is false, means not using it.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -94,17 +96,17 @@ def sample_neighbors(row,
raise ValueError(
f"`eids` should not be None if `return_eids` is True.")

if flag_perm_buffer:
if has_perm_buffer:
if perm_buffer is None:
raise ValueError(
f"`perm_buffer` should not be None if `flag_perm_buffer`"
f"`perm_buffer` should not be None if `has_perm_buffer`"
"is True.")

if _non_static_mode():
out_neighbors, out_count, out_eids = _C_ops.graph_sample_neighbors(
row, colptr, input_nodes, eids, perm_buffer, "sample_size",
sample_size, "return_eids", return_eids, "flag_perm_buffer",
flag_perm_buffer)
has_perm_buffer)
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count
Expand All @@ -118,7 +120,7 @@ def sample_neighbors(row,
if return_eids:
check_variable_and_dtype(eids, "Eids", ("int32", "int64"),
"graph_sample_neighbors")
if flag_perm_buffer:
if has_perm_buffer:
check_variable_and_dtype(perm_buffer, "Perm_Buffer", ("int32", "int64"),
"graph_sample_neighbors")

Expand All @@ -132,8 +134,7 @@ def sample_neighbors(row,
"Col_Ptr": colptr,
"X": input_nodes,
"Eids": eids if return_eids else None,
"Perm_Buffer":
perm_buffer if flag_perm_buffer else None
"Perm_Buffer": perm_buffer if has_perm_buffer else None
},
outputs={
"Out": out_neighbors,
Expand All @@ -143,7 +144,7 @@ def sample_neighbors(row,
attrs={
"sample_size": sample_size,
"return_eids": return_eids,
"flag_perm_buffer": flag_perm_buffer
"flag_perm_buffer": has_perm_buffer
})
if return_eids:
return out_neighbors, out_count, out_eids
Expand Down

0 comments on commit d5ac560

Please sign in to comment.