aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 11:03:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 11:03:35 -0700
commitdfcf85601a0372f3130baf9fe36605fe528144cc (patch)
treee1aafe304b0540e71f0989ff7c9b7c94efaef840 /tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
parent8ed40cdd3ea7e9aea996339678efba2b2b04e1ad (diff)
parenta979fb29de16dff6495b51cb1363eea752b43513 (diff)
Merge pull request #20636 from Intel-tensorflow:agramesh/conv_bwd_reorder_reuse
PiperOrigin-RevId: 205421644
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc9
1 files changed, 2 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 0af4568b47..b0f7faaa1a 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -722,14 +722,11 @@ class MklConv2DCustomBackpropInputOp
diff_src_tensor->flat<T>().data()));
// check if filter and diff_dst need reorder
- std::vector<primitive> net;
T* filter_data = nullptr;
if (fwd_filter_md.data.format !=
conv2d_bwd_input->GetFilterMemoryFormat()) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
- filter.CheckReorderToOpMem(
- bwd_input_pd->weights_primitive_desc(),
- &net);
+ filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
filter_data = static_cast<T*>(const_cast<T*>(
@@ -740,15 +737,13 @@ class MklConv2DCustomBackpropInputOp
if (diff_dst_md.data.format !=
conv2d_bwd_input->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
- diff_dst.CheckReorderToOpMem(
- bwd_input_pd->diff_dst_primitive_desc(), &net);
+ diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data = static_cast<T*>(const_cast<T*>(
diff_dst_tensor.flat<T>().data()));
}
- stream(stream::kind::eager).submit(net).wait();
// execute convolution input bwd
conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);