diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_concat_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_concat_op.cc | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index 31d1b949ef..d054f0d404 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -704,14 +704,14 @@ class MklConcatOp : public OpKernel { if (input_tensors[k].NumElements() == 0) continue; - auto src_dims = TFShapeToMklDnnDims( - mkl_input_shapes[k].GetTfShape()); auto src_md = mkl_input_shapes[k].GetMklLayout(); srcs[k].SetUsrMem(src_md, &input_tensors[k]); - if (src_md.data.format != mkl_common_format) + if (src_md.data.format != mkl_common_format) { + memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]); src_md = memory::desc(src_dims, MklDnnType<T>(), mkl_common_format); + } srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine)); } |