[IPU] add activation ops (#43662)
* add argmin and argsort ops (#800)

* add argmin and argsort ops

* Add dot bmm ops (#803)

* add bmm

* add dot op

* clean CreateConst

* clean CreateCast

* add activation ops (#808)

* add activation ops

* fix function-redefined error
gglin001 committed Jun 21, 2022
1 parent 2a795df commit 2353db3
Showing 14 changed files with 986 additions and 63 deletions.
@@ -119,6 +119,21 @@ Node *tanh_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_tanh");
}

Node *brelu_handler(Graph *graph, Node *node) {
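// brelu(x) = min(max(x, t_min), t_max); lowered to popart_clip with the
// bounds passed in as scalar constant tensors.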
auto *op = node->Op();
auto t_min_ = BOOST_GET_CONST(float, op->GetAttr("t_min"));
auto t_max_ = BOOST_GET_CONST(float, op->GetAttr("t_max"));
auto x = GetInputVarNode("X", node);
auto clip_min = CreateConst(graph, node, std::vector<float>{t_min_}, {1},
ONNXDataType::FLOAT)
->outputs.front();
auto clip_max = CreateConst(graph, node, std::vector<float>{t_max_}, {1},
ONNXDataType::FLOAT)
->outputs.front();
return CreateBaseOp(graph, node, "popart_clip", {x, cli_min, clip_max},
node->outputs);
}

Node *gelu_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto approximate_ = BOOST_GET_CONST(bool, op->GetAttr("approximate"));
@@ -160,6 +175,245 @@ Node *log_softmax_handler(Graph *graph, Node *node) {
node->outputs);
}

Node *elu_handler(Graph *graph, Node *node) {
auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
return CreateBaseOp(graph, node, "popart_elu", node->inputs, node->outputs,
{
{"alpha", alpha_},
});
}

Node *hard_shrink_handler(Graph *graph, Node *node) {
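// hard_shrink zeroes values in [-threshold, threshold]; popart_shrink with
// bias = 0 leaves the surviving values unshifted.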
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
return CreateBaseOp(graph, node, "popart_shrink", node->inputs, node->outputs,
{
{"lambd", threshold_},
{"bias", 0.0f},
});
}

Node *hard_sigmoid_handler(Graph *graph, Node *node) {
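// Paddle's slope/offset map to the ONNX-style alpha/beta attributes:
// y = clip(slope * x + offset, 0, 1).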
auto slope_ = BOOST_GET_CONST(float, node->Op()->GetAttr("slope"));
auto offset_ = BOOST_GET_CONST(float, node->Op()->GetAttr("offset"));
return CreateBaseOp(graph, node, "popart_hardsigmoid", node->inputs,
node->outputs,
{
{"alpha", slope_},
{"beta", offset_},
});
}

Node *hard_swish_handler(Graph *graph, Node *node) {
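// hard_swish(x) = x * clip(x + offset, 0, threshold) / scale, composed from
// popart_add, popart_clip, popart_mul and popart_div.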
auto x = GetInputVarNode("X", node);
auto scale_ = BOOST_GET_CONST(float, node->Op()->GetAttr("scale"));
auto offset_ = BOOST_GET_CONST(float, node->Op()->GetAttr("offset"));
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
auto scale_node =
CreateConst(graph, node, std::vector<float>{scale_}, {1}, GetVarDType(x))
->outputs.front();
auto offset_node =
CreateConst(graph, node, std::vector<float>{offset_}, {1}, GetVarDType(x))
->outputs.front();
auto add_node = CreateBaseOp(graph, node, "popart_add", {x, offset_node}, {})
->outputs.front();
auto clip_min = CreateConst(graph, node, std::vector<float>{0.0}, {1},
ONNXDataType::FLOAT)
->outputs.front();
auto clip_max = CreateConst(graph, node, std::vector<float>{threshold_}, {1},
ONNXDataType::FLOAT)
->outputs.front();
auto clip_node = CreateBaseOp(graph, node, "popart_clip",
{add_node, clip_min, clip_max}, {})
->outputs.front();
auto mul_node = CreateBaseOp(graph, node, "popart_mul", {x, clip_node}, {})
->outputs.front();
return CreateBaseOp(graph, node, "popart_div", {mul_node, scale_node},
{GetOutputVarNode("Out", node)});
}

Node *leaky_relu_handler(Graph *graph, Node *node) {
auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
return CreateBaseOp(graph, node, "popart_leakyrelu", node->inputs,
node->outputs,
{
{"alpha", alpha_},
});
}

Node *log10_handler(Graph *graph, Node *node) {
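// log10(x) = ln(x) / ln(10).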
auto x = GetInputVarNode("X", node);
float ln10 = 2.30258509299404568401;
auto ln10_tensor =
CreateConst(graph, node, std::vector<float>{ln10}, {1}, GetVarDType(x))
->outputs.front();
auto log = CreateBaseOp(graph, node, "popart_log", {x}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_div", {log, ln10_tensor},
node->outputs);
}

Node *log1p_handler(Graph *graph, Node *node) {
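// log1p(x) = ln(1 + x).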
auto x = GetInputVarNode("X", node);
auto one =
CreateConst(graph, node, std::vector<float>{1.0}, {1}, GetVarDType(x))
->outputs.front();
auto add =
CreateBaseOp(graph, node, "popart_add", {x, one}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_log", {add}, node->outputs);
}

Node *log2_handler(Graph *graph, Node *node) {
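// log2(x) = ln(x) / ln(2).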
auto x = GetInputVarNode("X", node);
float ln2 = 0.693147180559945309;
auto ln2_tensor =
CreateConst(graph, node, std::vector<float>{ln2}, {1}, GetVarDType(x))
->outputs.front();
auto log = CreateBaseOp(graph, node, "popart_log", {x}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_div", {log, ln2_tensor},
node->outputs);
}

Node *logsigmoid_handler(Graph *graph, Node *node) {
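// logsigmoid(x) = ln(sigmoid(x)).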
auto sigmoid = CreateBaseOp(graph, node, "popart_sigmoid",
{GetInputVarNode("X", node)}, {})
->outputs.front();
return CreateBaseOp(graph, node, "popart_log", {sigmoid}, node->outputs);
}

Node *mish_handler(Graph *graph, Node *node) {
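// mish(x) = x * tanh(softplus(x)); the ONNX-style softplus takes no
// threshold attribute, so only Paddle's default threshold (20.0) is accepted.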
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
if (!is_float_equal(threshold_, 20.0f)) {
PADDLE_THROW(platform::errors::Unimplemented(
"For mish op, only support threshold = 20.0"));
}
auto x = GetInputVarNode("X", node);
auto softplus =
CreateBaseOp(graph, node, "popart_softplus", {x}, {})->outputs.front();
auto tanh =
CreateBaseOp(graph, node, "popart_tanh", {softplus}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_mul", {x, tanh}, node->outputs);
}

Node *prelu_handler(Graph *graph, Node *node) {
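// prelu(x) = x if x >= 0 else alpha * x; a rank-1 alpha is reshaped to
// [-1, 1, ..., 1] below so it broadcasts along the channel dimension.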
auto x = GetInputVarNode("X", node);
auto alpha = GetInputVarNode("Alpha", node);
auto out = GetOutputVarNode("Out", node);
auto x_rank = x->Var()->GetShape().size();
auto alpha_rank = alpha->Var()->GetShape().size();
if (x_rank != alpha_rank) {
if (alpha_rank > 1) {
PADDLE_THROW(platform::errors::Unimplemented(
"For prelu op, only alpha with rank <= 1 is supported when "
"Rank(alpha) != Rank(input)."));
}
if (x_rank <= 1) {
PADDLE_THROW(platform::errors::Unimplemented(
"For prelu op, the rank of input should be greater than 1."));
}
auto shape = std::vector<int64_t>(x_rank - 1, 1);
shape[0] = -1;
int64_t size = shape.size();
auto dim = std::vector<int64_t>{size};
auto reshape_const =
CreateConst(graph, node, shape, dim, ONNXDataType::INT64)
->outputs.front();
alpha =
CreateBaseOp(graph, node, "popart_reshape", {alpha, reshape_const}, {})
->outputs.front();
}
return CreateBaseOp(graph, node, "popart_prelu", {x, alpha}, {out});
}

Node *relu6_handler(Graph *graph, Node *node) {
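// relu6(x) = clip(x, 0, threshold), with threshold defaulting to 6.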
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
auto clip_min = CreateConst(graph, node, std::vector<float>{0.0}, {1},
ONNXDataType::FLOAT)
->outputs.front();
auto clip_max = CreateConst(graph, node, std::vector<float>{threshold_}, {1},
ONNXDataType::FLOAT)
->outputs.front();
return CreateBaseOp(graph, node, "popart_clip",
{GetInputVarNode("X", node), cli_min, clip_max},
node->outputs);
}

Node *rsqrt_handler(Graph *graph, Node *node) {
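// rsqrt(x) = 1 / sqrt(x), lowered as reciprocal(sqrt(x)).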
auto sqrt =
CreateBaseOp(graph, node, "popart_sqrt", {GetInputVarNode("X", node)}, {})
->outputs.front();
return CreateBaseOp(graph, node, "popart_reciprocal", {sqrt}, node->outputs);
}

Node *selu_handler(Graph *graph, Node *node) {
auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
auto scale_ = BOOST_GET_CONST(float, node->Op()->GetAttr("scale"));
return CreateBaseOp(graph, node, "popart_selu", node->inputs, node->outputs,
{
{"alpha", alpha_},
{"gamma", scale_},
});
}

Node *silu_handler(Graph *graph, Node *node) {
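// silu(x) = x * sigmoid(x).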
auto x = GetInputVarNode("X", node);
auto sigmoid =
CreateBaseOp(graph, node, "popart_sigmoid", {x}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_mul", {x, sigmoid}, node->outputs);
}

Node *softshrink_handler(Graph *graph, Node *node) {
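// softshrink shifts surviving values toward zero: popart_shrink with
// bias = lambda gives y = x -/+ lambda outside [-lambda, lambda], else 0.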
auto lambda_ = BOOST_GET_CONST(float, node->Op()->GetAttr("lambda"));
return CreateBaseOp(graph, node, "popart_shrink", node->inputs, node->outputs,
{
{"lambd", lambda_},
{"bias", lambda_},
});
}

Node *square_handler(Graph *graph, Node *node) {
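// square(x) = x * x, lowered as a self-multiply.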
auto x = GetInputVarNode("X", node);
return CreateBaseOp(graph, node, "popart_mul", {x, x}, node->outputs);
}

Node *swish_handler(Graph *graph, Node *node) {
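// swish(x) = x * sigmoid(beta * x).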
auto x = GetInputVarNode("X", node);
auto out = GetOutputVarNode("Out", node);
auto beta_ = BOOST_GET_CONST(float, node->Op()->GetAttr("beta"));
auto beta_node =
CreateConst(graph, node, std::vector<float>{beta_}, {1}, GetVarDType(x))
->outputs.front();
auto beta_x_node = CreateBaseOp(graph, node, "popart_mul", {x, beta_node}, {})
->outputs.front();
auto sigmoid_node =
CreateBaseOp(graph, node, "popart_sigmoid", {beta_x_node}, {})
->outputs.front();
return CreateBaseOp(graph, node, "popart_mul", {x, sigmoid_node}, {out});
}

Node *tanh_shrink_handler(Graph *graph, Node *node) {
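// tanh_shrink(x) = x - tanh(x).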
auto x = GetInputVarNode("X", node);
auto tanh =
CreateBaseOp(graph, node, "popart_tanh", {x}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_sub", {x, tanh}, node->outputs);
}

Node *thresholded_relu_handler(Graph *graph, Node *node) {
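// Paddle's threshold maps to the alpha attribute of popart_thresholdedrelu.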
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
auto x = GetInputVarNode("X", node);
return CreateBaseOp(graph, node, "popart_thresholdedrelu", {x}, node->outputs,
{
{"alpha", threshold_},
});
}

} // namespace
} // namespace ipu
} // namespace platform
@@ -188,5 +442,26 @@ REGISTER_HANDLER(softsign, softsign_handler);
REGISTER_HANDLER(sqrt, sqrt_handler);
REGISTER_HANDLER(tan, tan_handler);
REGISTER_HANDLER(tanh, tanh_handler);
REGISTER_HANDLER(brelu, brelu_handler);
REGISTER_HANDLER(gelu, gelu_handler);
REGISTER_HANDLER(log_softmax, log_softmax_handler);
REGISTER_HANDLER(elu, elu_handler);
REGISTER_HANDLER(hard_shrink, hard_shrink_handler);
REGISTER_HANDLER(hard_sigmoid, hard_sigmoid_handler);
REGISTER_HANDLER(hard_swish, hard_swish_handler);
REGISTER_HANDLER(leaky_relu, leaky_relu_handler);
REGISTER_HANDLER(log10, log10_handler);
REGISTER_HANDLER(log1p, log1p_handler);
REGISTER_HANDLER(log2, log2_handler);
REGISTER_HANDLER(logsigmoid, logsigmoid_handler);
REGISTER_HANDLER(mish, mish_handler);
REGISTER_HANDLER(prelu, prelu_handler);
REGISTER_HANDLER(relu6, relu6_handler);
REGISTER_HANDLER(rsqrt, rsqrt_handler);
REGISTER_HANDLER(selu, selu_handler);
REGISTER_HANDLER(silu, silu_handler);
REGISTER_HANDLER(softshrink, softshrink_handler);
REGISTER_HANDLER(square, square_handler);
REGISTER_HANDLER(swish, swish_handler);
REGISTER_HANDLER(tanh_shrink, tanh_shrink_handler);
REGISTER_HANDLER(thresholded_relu, thresholded_relu_handler);
@@ -117,15 +117,20 @@ const bool is_float_equal(float a, float b, float eps) {
return std::fabs(a - b) <= eps;
}

const ONNXDataType GetVarDType(const Node *node) {
auto var = node->Var();
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::Unavailable("Node is not a variable."));
auto proto_var_type = var->GetDataType();
return VarType2OnnxDType(proto_var_type);
}

const ONNXDataType GetOutputVarDType(const Node *node,
const std::string &output_name) {
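// Resolve the named output variable node and reuse GetVarDType for it.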
auto out_node = GetOutputVarNode(output_name, node);
PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable(
"Node's out node does not exist."));
return GetVarDType(out_node);
}

} // namespace ipu
@@ -78,8 +78,9 @@ Node *GetOutputVarNodeByVarName(const std::string &var_name,
const Node *op_node);

const bool is_float_equal(float a, float b, float eps = 1e-8);
const ONNXDataType GetVarDType(const Node *node);
const ONNXDataType GetOutputVarDType(const Node *node,
const std::string &output_name = "Out");

} // namespace ipu
} // namespace platform
