aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_maxpooling_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_maxpooling_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_maxpooling_op.cc27
1 files changed, 10 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index e27881f882..1e0ee258b0 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -276,11 +276,6 @@ class MklMaxPoolingGradOp : public OpKernel {
mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
static_cast<const void*>(output_tensor->flat<T>().data()));
- int64 output_size = output_tensor->NumElements();
- for (int64 i = 0; i < output_size; ++i) {
- (static_cast<float*>(mkl_context.pooling_res[dnnResourceDiffSrc]))[i] = 0;
- }
-
CHECK_EQ(
dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
E_SUCCESS);
@@ -387,19 +382,18 @@ class MklMaxPoolingGradOp : public OpKernel {
if (workspace_enabled == false) {
if (convert_input != nullptr) {
if (input_in_mkl_format == false) {
- CHECK_EQ(dnnConversionExecute_F32(
- convert_input,
- const_cast<void*>(static_cast<const void*>(
- tensor_in.flat<T>().data())),
- input_buf),
- E_SUCCESS);
+ CHECK_EQ(
+ dnnConversionExecute_F32(
+ convert_input, const_cast<void*>(static_cast<const void*>(
+ tensor_in.flat<T>().data())),
+ input_buf),
+ E_SUCCESS);
CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS);
convert_input = nullptr;
} else {
input_shape.GetConvertedFlatData(
- lt_input_prim,
- const_cast<void*>(
- static_cast<const void*>(tensor_in.flat<T>().data())),
+ lt_input_prim, const_cast<void*>(static_cast<const void*>(
+ tensor_in.flat<T>().data())),
input_buf);
}
pooling_resfwd[dnnResourceSrc] = input_buf;
@@ -444,9 +438,8 @@ class MklMaxPoolingGradOp : public OpKernel {
CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS);
} else {
output_backprop_shape.GetConvertedFlatData(
- lt_outbackprop_prim,
- const_cast<void*>(
- static_cast<const void*>(out_backprop.flat<T>().data())),
+ lt_outbackprop_prim, const_cast<void*>(static_cast<const void*>(
+ out_backprop.flat<T>().data())),
outbackprop_buf);
}
pooling_res[dnnResourceDiffDst] = outbackprop_buf;