aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/pooling_ops_common.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/pooling_ops_common.cc')
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.cc10
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index 7dee751c4f..ac90f67ce0 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -143,7 +143,7 @@ void DnnPoolingOp<T>::Compute(
perftools::gputools::dnn::PoolingMode pooling_mode,
const std::vector<int32>& size, const std::vector<int32>& stride,
Padding padding, TensorFormat data_format, const Tensor& tensor_in,
- const TensorShape& tensor_out_shape) {
+ const TensorShape& tensor_out_shape, bool propagate_nans) {
Tensor* tensor_out = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, tensor_out_shape, &tensor_out));
@@ -188,7 +188,8 @@ void DnnPoolingOp<T>::Compute(
.set_vertical_stride(params.row_stride)
.set_horizontal_stride(params.col_stride)
.set_vertical_padding(params.pad_rows)
- .set_horizontal_padding(params.pad_cols);
+ .set_horizontal_padding(params.pad_cols)
+ .set_propagate_nans(propagate_nans);
perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(params.tensor_in_batch)
@@ -237,7 +238,7 @@ void DnnPoolingGradOp<T>::Compute(
const std::vector<int32>& size, const std::vector<int32>& stride,
Padding padding, TensorFormat data_format, const Tensor* tensor_in,
const Tensor* tensor_out, const Tensor& out_backprop,
- const TensorShape& tensor_in_shape) {
+ const TensorShape& tensor_in_shape, bool propagate_nans) {
CHECK((pooling_mode != perftools::gputools::dnn::PoolingMode::kMaximum) ||
(tensor_in && tensor_out))
<< "For MaxPoolGrad, both tensor_in and tensor_out needs to be "
@@ -327,7 +328,8 @@ void DnnPoolingGradOp<T>::Compute(
.set_vertical_stride(params.row_stride)
.set_horizontal_stride(params.col_stride)
.set_vertical_padding(params.pad_rows)
- .set_horizontal_padding(params.pad_cols);
+ .set_horizontal_padding(params.pad_cols)
+ .set_propagate_nans(propagate_nans);
perftools::gputools::dnn::BatchDescriptor orig_output_desc;
orig_output_desc.set_count(params.tensor_in_batch)