aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_concat_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_concat_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_concat_op.cc6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 31d1b949ef..d054f0d404 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -704,14 +704,14 @@ class MklConcatOp : public OpKernel {
if (input_tensors[k].NumElements() == 0)
continue;
- auto src_dims = TFShapeToMklDnnDims(
- mkl_input_shapes[k].GetTfShape());
auto src_md = mkl_input_shapes[k].GetMklLayout();
srcs[k].SetUsrMem(src_md, &input_tensors[k]);
- if (src_md.data.format != mkl_common_format)
+ if (src_md.data.format != mkl_common_format) {
+ memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]);
src_md = memory::desc(src_dims, MklDnnType<T>(),
mkl_common_format);
+ }
srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine));
}