diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_maxpooling_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_maxpooling_op.cc | 24 |
1 files changed, 11 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index 44e1b93dba..846bb5710d 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -311,7 +311,7 @@ class MklMaxPoolingGradOp : public OpKernel { lt_input_user = (dnnLayout_t)input_shape.GetCurLayout(); } - // We dont care about the output layout for now as we can create it from + // We don't care about the output layout for now as we can create it from // primitives for the max pooling fwd prop if (outbackprop_in_mkl_format == false) { CHECK_EQ(dnnLayoutCreate_F32(<_outbackprop_user, params.in_dim, @@ -382,19 +382,18 @@ class MklMaxPoolingGradOp : public OpKernel { if (workspace_enabled == false) { if (convert_input != nullptr) { if (input_in_mkl_format == false) { - CHECK_EQ(dnnConversionExecute_F32( - convert_input, - const_cast<void*>(static_cast<const void*>( - tensor_in.flat<T>().data())), - input_buf), - E_SUCCESS); + CHECK_EQ( + dnnConversionExecute_F32( + convert_input, const_cast<void*>(static_cast<const void*>( + tensor_in.flat<T>().data())), + input_buf), + E_SUCCESS); CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS); convert_input = nullptr; } else { input_shape.GetConvertedFlatData( - lt_input_prim, - const_cast<void*>( - static_cast<const void*>(tensor_in.flat<T>().data())), + lt_input_prim, const_cast<void*>(static_cast<const void*>( + tensor_in.flat<T>().data())), input_buf); } pooling_resfwd[dnnResourceSrc] = input_buf; @@ -439,9 +438,8 @@ class MklMaxPoolingGradOp : public OpKernel { CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS); } else { output_backprop_shape.GetConvertedFlatData( - lt_outbackprop_prim, - const_cast<void*>( - static_cast<const void*>(out_backprop.flat<T>().data())), + lt_outbackprop_prim, const_cast<void*>(static_cast<const void*>( + out_backprop.flat<T>().data())), outbackprop_buf); } pooling_res[dnnResourceDiffDst] = outbackprop_buf; |