New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[geometric]Add paddle.geometric.send_u_recv API #44580
[geometric]Add paddle.geometric.send_u_recv API #44580
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -58,6 +58,10 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { | |||
"The input tensor with data type float32, float64, int32, int64."); | |||
AddInput("Src_index", "The source index tensor."); | |||
AddInput("Dst_index", "The destination index tensor."); | |||
AddInput("OutSizeTensor", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的命名为什么要强调是Tensor,看看是否命名是Out_size,和上面的变量命名一致
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成Out_size了
@@ -58,6 +58,10 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { | |||
"The input tensor with data type float32, float64, int32, int64."); | |||
AddInput("Src_index", "The source index tensor."); | |||
AddInput("Dst_index", "The destination index tensor."); | |||
AddInput("OutSizeTensor", | |||
"(Tensor<int>, optional). The 0th dimension of the output." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tensor, optional
为什么名字要加Tensor,是不是Out_size更好
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tensor, optional的写法是参考其他类似的 OP,这里不改动;另外已经改成Out_size了。
@@ -886,7 +886,7 @@ | |||
backward : gelu_grad | |||
|
|||
- api : graph_send_recv | |||
args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", int64_t out_size = 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里求解释一下,为什么不用设置默认值了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
发现可以加默认值,不会报错,改好了。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out_size 默认 -1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
个人觉得还是默认0比较好。
const auto& src_dims = x.dims(); | ||
int64_t memset_size = 1; | ||
if (out_size <= 0) { | ||
for (int i = 0; i < src_dims.size(); ++i) { | ||
memset_size *= src_dims[i]; | ||
} | ||
} else { | ||
// set out dim following out_size. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set大写
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/phi/infermeta/ternary.cc
Outdated
} | ||
out->set_dims(phi::make_ddim(dims_)); | ||
} | ||
out->set_dims(dims); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里把不定长那个维度设置成-1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paddle/phi/infermeta/ternary.cc
Outdated
} else { | ||
dst_count->set_dims({out_size}); | ||
} | ||
dst_count->set_dims({dims[0]}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
const auto& src_dims = x.dims(); | ||
int64_t memset_size = 1; | ||
if (out_size <= 0) { | ||
for (int i = 0; i < src_dims.size(); ++i) { | ||
memset_size *= src_dims[i]; | ||
} | ||
} else { | ||
// set out dim following out_size. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set -> Set
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -0,0 +1,162 @@ | |||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
似乎不用每个文件都要放一个函数,看看能一类放在一个文件里面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把文件名改成了 send_recv.py,用来放置既有 send 也有 recv 过程的函数;后续还会有 send.py,recv.py,用来存放其他类型的函数。
[1, 4, 5]] | ||
|
||
Args: | ||
x (Tensor): The input tensor, and the available data type is float32, float64, int32, int64. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是不是已经支持了float16
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有支持,在另一个 PR 修改。
# TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. | ||
|
||
if _in_legacy_dygraph(): | ||
out_size = convert_out_size_to_list(out_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要确认一下是否能动转静
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
找留杰确认了,动转静不走这个分支。
e7f16f9
to
f8fac9a
Compare
pool_type.upper(), out_size) | ||
|
||
check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), | ||
"send_u_recv") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些地方得改成 graph_send_recv,因为 OP 的名字是这个。
return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, | ||
pool_type.upper(), out_size) | ||
|
||
check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看起来float16也得加上。在下一个 PR 添加。
… change_graph_send_recv_outsize
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
APIs
Describe
out_size
attribute from int64_t to INTArray, supporting Tensor input and number input.Note