diff options
Diffstat (limited to 'tensorflow/core/kernels/pooling_ops_common.cc')
-rw-r--r-- | tensorflow/core/kernels/pooling_ops_common.cc | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 7dee751c4f..ac90f67ce0 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -143,7 +143,7 @@ void DnnPoolingOp<T>::Compute( perftools::gputools::dnn::PoolingMode pooling_mode, const std::vector<int32>& size, const std::vector<int32>& stride, Padding padding, TensorFormat data_format, const Tensor& tensor_in, - const TensorShape& tensor_out_shape) { + const TensorShape& tensor_out_shape, bool propagate_nans) { Tensor* tensor_out = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, tensor_out_shape, &tensor_out)); @@ -188,7 +188,8 @@ void DnnPoolingOp<T>::Compute( .set_vertical_stride(params.row_stride) .set_horizontal_stride(params.col_stride) .set_vertical_padding(params.pad_rows) - .set_horizontal_padding(params.pad_cols); + .set_horizontal_padding(params.pad_cols) + .set_propagate_nans(propagate_nans); perftools::gputools::dnn::BatchDescriptor input_desc; input_desc.set_count(params.tensor_in_batch) @@ -237,7 +238,7 @@ void DnnPoolingGradOp<T>::Compute( const std::vector<int32>& size, const std::vector<int32>& stride, Padding padding, TensorFormat data_format, const Tensor* tensor_in, const Tensor* tensor_out, const Tensor& out_backprop, - const TensorShape& tensor_in_shape) { + const TensorShape& tensor_in_shape, bool propagate_nans) { CHECK((pooling_mode != perftools::gputools::dnn::PoolingMode::kMaximum) || (tensor_in && tensor_out)) << "For MaxPoolGrad, both tensor_in and tensor_out needs to be " @@ -327,7 +328,8 @@ void DnnPoolingGradOp<T>::Compute( .set_vertical_stride(params.row_stride) .set_horizontal_stride(params.col_stride) .set_vertical_padding(params.pad_rows) - .set_horizontal_padding(params.pad_cols); + .set_horizontal_padding(params.pad_cols) + .set_propagate_nans(propagate_nans); perftools::gputools::dnn::BatchDescriptor orig_output_desc; orig_output_desc.set_count(params.tensor_in_batch) |