aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_concat_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_concat_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_concat_op.cc10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 31d1b949ef..6f490cdc23 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -704,14 +704,14 @@ class MklConcatOp : public OpKernel {
if (input_tensors[k].NumElements() == 0)
continue;
- auto src_dims = TFShapeToMklDnnDims(
- mkl_input_shapes[k].GetTfShape());
auto src_md = mkl_input_shapes[k].GetMklLayout();
srcs[k].SetUsrMem(src_md, &input_tensors[k]);
- if (src_md.data.format != mkl_common_format)
+ if (src_md.data.format != mkl_common_format) {
+ memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]);
src_md = memory::desc(src_dims, MklDnnType<T>(),
mkl_common_format);
+ }
srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine));
}
@@ -756,11 +756,10 @@ class MklConcatOp : public OpKernel {
}
std::vector<primitive::at> inputs;
- std::vector<primitive> net;
if (isMklReorderNeeded) {
for (int k = 0; k < input_tensors.size(); k++) {
if (input_tensors[k].NumElements() > 0) {
- srcs[k].CheckReorderToOpMem(srcs_pd[k], &net);
+ srcs[k].CheckReorderToOpMem(srcs_pd[k]);
}
}
}
@@ -806,6 +805,7 @@ class MklConcatOp : public OpKernel {
dst.SetUsrMem(dst_md, dst_tensor);
auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
+ std::vector<primitive> net;
net.push_back(concat_op);
stream(stream::kind::eager).submit(net).wait();
} catch (mkldnn::error& e) {