Skip to content

Commit

Permalink
Merge pull request #41688 from Intel-tensorflow:yang/fix_concat
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 323612153
Change-Id: I72c92a6af1c5a908e599185ce6a03b6f42d04533
  • Loading branch information
tensorflower-gardener committed Jul 28, 2020
2 parents d98f120 + 0be68a5 commit 482d273
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tensorflow/core/kernels/mkl_concat_op.cc
Expand Up @@ -376,19 +376,20 @@ class MklConcatFwdPrimitive : public MklPrimitive {
context_.data_mem_shdptr.push_back(src_mem);
context_.data_mem.push_back(*context_.data_mem_shdptr[i]);
}
// Store the expected memory format
context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
MklDnnType<T>(),
concat_fwd_dims.mkl_common_format));
// Create a concat primitive descriptor
#ifdef ENABLE_MKLDNN_V1
context_.fwd_pd.reset(new concat::primitive_desc(
concat_fwd_dims.concat_dims, context_.src_md, cpu_engine_));
*context_.dst_md, concat_fwd_dims.concat_dims, context_.src_md,
cpu_engine_));
#else
context_.fwd_pd.reset(new concat::primitive_desc(
concat_fwd_dims.concat_dims, context_.src_pd));
#endif // ENABLE_MKLDNN_V1

// Store the expected memory format
context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
MklDnnType<T>(),
concat_fwd_dims.mkl_common_format));
#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.dst_mem.reset(
Expand All @@ -404,8 +405,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
context_.concat_fwd.reset(new concat(*context_.fwd_pd));
std::unordered_map<int, memory> net_args = {
{ MKLDNN_ARG_DST,
*context_.dst_mem }};
{MKLDNN_ARG_DST, *context_.dst_mem}};
for (int i = 0; i < concat_fwd_dims.num_inputs; ++i) {
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, context_.data_mem[i]});
}
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/python/kernel_tests/concat_op_test.py
Expand Up @@ -68,6 +68,22 @@ def testVStack(self):
self.assertAllEqual(result[:, :4], params[p1])
self.assertAllEqual(result[:, 4:], params[p2])

@test_util.run_deprecated_v1
def test4DStack(self):
with self.session(use_gpu=True):
p1 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 1, 1])
p2 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 4, 1])
c = array_ops.concat([p1, p2], 2)
params = {
p1: np.random.rand(2, 3, 1, 1).astype("f"),
p2: np.random.rand(2, 3, 4, 1).astype("f")
}
result = c.eval(feed_dict=params)

self.assertEqual(result.shape, c.get_shape())
self.assertAllEqual(result[:, :, :1, :], params[p1])
self.assertAllEqual(result[:, :, 1:, :], params[p2])

def testInt32GPU(self):
with test_util.use_gpu():
p1 = np.random.rand(2, 3).astype("i")
Expand Down

0 comments on commit 482d273

Please sign in to comment.