diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_concat_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_concat_op.cc | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index 31d1b949ef..6f490cdc23 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -704,14 +704,14 @@ class MklConcatOp : public OpKernel { if (input_tensors[k].NumElements() == 0) continue; - auto src_dims = TFShapeToMklDnnDims( - mkl_input_shapes[k].GetTfShape()); auto src_md = mkl_input_shapes[k].GetMklLayout(); srcs[k].SetUsrMem(src_md, &input_tensors[k]); - if (src_md.data.format != mkl_common_format) + if (src_md.data.format != mkl_common_format) { + memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]); src_md = memory::desc(src_dims, MklDnnType<T>(), mkl_common_format); + } srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine)); } @@ -756,11 +756,10 @@ class MklConcatOp : public OpKernel { } std::vector<primitive::at> inputs; - std::vector<primitive> net; if (isMklReorderNeeded) { for (int k = 0; k < input_tensors.size(); k++) { if (input_tensors[k].NumElements() > 0) { - srcs[k].CheckReorderToOpMem(srcs_pd[k], &net); + srcs[k].CheckReorderToOpMem(srcs_pd[k]); } } } @@ -806,6 +805,7 @@ class MklConcatOp : public OpKernel { dst.SetUsrMem(dst_md, dst_tensor); auto concat_op = concat(concat_pd, inputs, dst.GetOpMem()); + std::vector<primitive> net; net.push_back(concat_op); stream(stream::kind::eager).submit(net).wait(); } catch (mkldnn::error& e) { |