diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 9 |
1 files changed, 2 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 0af4568b47..b0f7faaa1a 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -722,14 +722,11 @@ class MklConv2DCustomBackpropInputOp diff_src_tensor->flat<T>().data())); // check if filter and diff_dst need reorder - std::vector<primitive> net; T* filter_data = nullptr; if (fwd_filter_md.data.format != conv2d_bwd_input->GetFilterMemoryFormat()) { filter.SetUsrMem(fwd_filter_md, &filter_tensor); - filter.CheckReorderToOpMem( - bwd_input_pd->weights_primitive_desc(), - &net); + filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc()); filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle()); } else { filter_data = static_cast<T*>(const_cast<T*>( @@ -740,15 +737,13 @@ class MklConv2DCustomBackpropInputOp if (diff_dst_md.data.format != conv2d_bwd_input->GetDiffDstMemoryFormat()) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - diff_dst.CheckReorderToOpMem( - bwd_input_pd->diff_dst_primitive_desc(), &net); + diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc()); diff_dst_data = static_cast<T*>( diff_dst.GetOpMem().get_data_handle()); } else { diff_dst_data = static_cast<T*>(const_cast<T*>( diff_dst_tensor.flat<T>().data())); } - stream(stream::kind::eager).submit(net).wait(); // execute convolution input bwd conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data); |