diff options
Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op.cc')
-rw-r--r-- | tensorflow/core/kernels/maxpooling_op.cc | 20 |
1 files changed, 9 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index aaaf45d3e7..507fc99837 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -404,10 +404,10 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel { "Pooling is not yet supported on the batch dimension.")); if (use_dnn_) { - DnnPoolingGradOp<T>::Compute( - context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize, - stride, padding_, data_format_, &tensor_in, &tensor_out, out_backprop, - output_shape, propagate_nans_); + DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, + ksize, stride, padding_, data_format_, + &tensor_in, &tensor_out, out_backprop, + output_shape, propagate_nans_); } else { CHECK(data_format_ == FORMAT_NHWC) << "Non-Cudnn MaxPoolGrad only supports NHWC format"; @@ -1136,10 +1136,9 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel { // These is_int8x4 checks avoid linker errors for missing qint8 kernels. if (!is_int8x4 && use_dnn_ && data_format_ == FORMAT_NCHW) { - DnnPoolingOp<T>::Compute(context, - perftools::gputools::dnn::PoolingMode::kMaximum, - ksize_, stride_, padding_, data_format_, - tensor_in, out_shape, propagate_nans_); + DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_, + stride_, padding_, data_format_, tensor_in, + out_shape, propagate_nans_); } else { Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); @@ -1240,9 +1239,8 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel { ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height, params.out_width, params.depth); if (use_dnn_ && data_format_ == FORMAT_NCHW) { - DnnPoolingOp<T>::Compute(context, - perftools::gputools::dnn::PoolingMode::kMaximum, - ksize, stride, padding_, data_format_, tensor_in, + DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize, + stride, padding_, data_format_, tensor_in, out_shape, propagate_nans_); } else { CHECK(data_format_ == FORMAT_NHWC) |