diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_maxpooling_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_maxpooling_op.cc | 27 |
1 files changed, 10 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index e27881f882..1e0ee258b0 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -276,11 +276,6 @@ class MklMaxPoolingGradOp : public OpKernel { mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>( static_cast<const void*>(output_tensor->flat<T>().data())); - int64 output_size = output_tensor->NumElements(); - for (int64 i = 0; i < output_size; ++i) { - (static_cast<float*>(mkl_context.pooling_res[dnnResourceDiffSrc]))[i] = 0; - } - CHECK_EQ( dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res), E_SUCCESS); @@ -387,19 +382,18 @@ class MklMaxPoolingGradOp : public OpKernel { if (workspace_enabled == false) { if (convert_input != nullptr) { if (input_in_mkl_format == false) { - CHECK_EQ(dnnConversionExecute_F32( - convert_input, - const_cast<void*>(static_cast<const void*>( - tensor_in.flat<T>().data())), - input_buf), - E_SUCCESS); + CHECK_EQ( + dnnConversionExecute_F32( + convert_input, const_cast<void*>(static_cast<const void*>( + tensor_in.flat<T>().data())), + input_buf), + E_SUCCESS); CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS); convert_input = nullptr; } else { input_shape.GetConvertedFlatData( - lt_input_prim, - const_cast<void*>( - static_cast<const void*>(tensor_in.flat<T>().data())), + lt_input_prim, const_cast<void*>(static_cast<const void*>( + tensor_in.flat<T>().data())), input_buf); } pooling_resfwd[dnnResourceSrc] = input_buf; @@ -444,9 +438,8 @@ class MklMaxPoolingGradOp : public OpKernel { CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS); } else { output_backprop_shape.GetConvertedFlatData( - lt_outbackprop_prim, - const_cast<void*>( - static_cast<const void*>(out_backprop.flat<T>().data())), + lt_outbackprop_prim, const_cast<void*>(static_cast<const void*>( + out_backprop.flat<T>().data())), outbackprop_buf); } pooling_res[dnnResourceDiffDst] = outbackprop_buf; |