Skip to content

Commit

Permalink
[cherry-pick] updating mul and matmul with set_mem_desc and fix squee…
Browse files Browse the repository at this point in the history
…ze_transpose for MKLDNN (#47951)

* Fix slice bugs in MKLDNN when input dims are zeros (#46671)

* fix slice bugs

* fix

* update code

* fix

* update code

* updating mul and matmul with set_mem_desc (#45624)

* - mul & matmul changes

- fix

- bs16 correction of strides

* - cosmetic fixes

* - lint

* - fix

* - fix

* - format -> mem_desc

* - fix

* - fix

* - fix

* - fix

* - fix

* fix squeeze_transpose (#47911)

Co-authored-by: Jacek Czaja <jacek.czaja@intel.com>
  • Loading branch information
yeliang2258 and jczaja committed Nov 29, 2022
1 parent 7a0b862 commit 9e2ba9b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 32 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Expand Up @@ -1045,6 +1045,7 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
PDNode *patterns::Squeeze2Transpose2::operator()() {
auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
->AsInput()
->assert_has_n_outputs(1)
->assert_is_op_input("squeeze2", "X");
auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
->assert_is_op("squeeze2")
Expand Down
33 changes: 18 additions & 15 deletions paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
Expand Up @@ -214,10 +214,7 @@ class MatMulMKLDNNHandler
}
astream.wait();

auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims()));
}

std::shared_ptr<dnnl::memory> AcquireDstMemory(
Expand Down Expand Up @@ -651,10 +648,18 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
auto &astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);

// TODO(jczaja): Explain why the int8 format of dst is ABCD and does not
// need a permute
if (IsOutputFused(ctx) && !IsInt8<T_out>()) {
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
out->set_mem_desc(
permuted_md.reshape(phi::vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc(
dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
}
}

template <typename T>
Expand Down Expand Up @@ -836,8 +841,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
reduction_p->execute(astream, reduction_args);
astream.wait();

dx->set_format(paddle::platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(squeezed_dims)));
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
}

std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
Expand Down Expand Up @@ -1119,9 +1123,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
matmul_p->execute(astream, matmul_args);
astream.wait();

out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))));
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}

template <typename T>
Expand Down Expand Up @@ -1184,13 +1187,13 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext &ctx) const {
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
dx->set_format(x.format());
dx->set_mem_desc(x.mem_desc());
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
dy->set_format(y.format());
dy->set_mem_desc(y.mem_desc());
}
}
}
Expand Down
26 changes: 12 additions & 14 deletions paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
Expand Up @@ -221,7 +221,7 @@ class MulPrimitiveFactory {
to_void_cast<T>(x_tmp.data<T>()));

x_tmp.Resize(data->dims());
x_tmp.set_format(platform::GetMKLDNNFormat(dst_mdesc));
x_tmp.set_mem_desc(dst_mdesc);
data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
} else {
data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
Expand All @@ -235,11 +235,7 @@ class MulPrimitiveFactory {
const Tensor *in) {
x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));

if (out->format() == MKLDNNMemoryFormat::undef) {
auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format);
}
out->set_mem_desc(output_->get_desc());
}

template <typename T>
Expand Down Expand Up @@ -272,7 +268,7 @@ class MulPrimitiveFactory {
auto buffer_size = dst_desc.get_size();

OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
output->set_format(paddle::platform::GetMKLDNNFormat(dst_desc));
output->set_mem_desc(dst_desc);
return memory(dst_desc, engine_, to_void_cast<OT>(output_data));
}

Expand Down Expand Up @@ -392,9 +388,10 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
if (out_dims.size() != 2) {
out->Resize(out_dims);
}
out->set_layout(DataLayout::kMKLDNN);
out->set_format(platform::MKLDNNFormatForSize(out_dims.size(),
MKLDNNMemoryFormat::nchw));

auto in_md = dnnl::memory::desc(*dnnl_primitive_desc_query_md(
mul.get_primitive_desc(), dnnl_query_dst_md, 0));
out->set_mem_desc(in_md.reshape(phi::vectorize<int64_t>(out->dims())));
}
};

Expand Down Expand Up @@ -442,10 +439,11 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
matmul_p->execute(astream, matmul_args);
astream.wait();

out->set_layout(framework::DataLayout::kMKLDNN);
// plain output formats are enforced inside handler
out->set_format(platform::MKLDNNFormatForSize(
out->dims().size(), dnnl::memory::format_tag::nchw));
// This kernel flattens dims, so the unflattened dims need to be set in
// out. reshape requires a plain layout, but MatMulV2MKLDNNHandler enforces
// one, so it should work
out->set_mem_desc(
dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
}

private:
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/platform/mkldnn_reuse.h
Expand Up @@ -301,7 +301,8 @@ class MatMulV2MKLDNNHandler
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}

if (!IsInt8<OT>() && !IsBfloat16<OT>() && is_output_fused) {
// TODO(jczaja): Why not for int8??
if (!IsInt8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}

Expand Down
6 changes: 4 additions & 2 deletions paddle/phi/kernels/transfer_layout_kernel.cc
Expand Up @@ -121,8 +121,10 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
OneDNNContext::tls().set_cur_paddle_data_layout(src_layout);
}

out->set_layout(DataLayout::ONEDNN);
out->set_format(out_format);
dnnl::memory::desc out_mem_desc(vectorize<int64_t>(out->dims()),
funcs::ToOneDNNDataType(x.dtype()),
out_format);
out->set_mem_desc(out_mem_desc);
} else if (src_layout == DataLayout::ONEDNN &&
dst_layout != DataLayout::ONEDNN) {
// Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel
Expand Down

0 comments on commit 9e2ba9b

Please sign in to comment.