-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc        | 110
-rw-r--r--  tensorflow/core/framework/common_shape_fns.h         |   3
-rw-r--r--  tensorflow/core/kernels/maxpooling_op.cc             | 450
-rw-r--r--  tensorflow/core/kernels/pooling_ops_common.h         | 215
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                        |  96
-rw-r--r--  tensorflow/python/kernel_tests/pooling_ops_test.py   | 726
-rw-r--r--  tensorflow/python/ops/hidden_ops.txt                 |   1
-rw-r--r--  tensorflow/python/ops/nn_grad.py                     |  31
8 files changed, 1289 insertions(+), 343 deletions(-)
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 9df5cbdec0..bd5d6e4af4 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -673,6 +673,116 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
+
+ string data_format;
+ Status s = c->GetAttr("data_format", &data_format);
+
+ std::vector<int32> kernel_sizes;
+ std::vector<int32> strides;
+
+ if (c->num_inputs() + 2 == num_inputs) {
+ TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
+
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ } else {
+ // Verify shape of ksize and strides input.
+ ShapeHandle size;
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
+
+ const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
+ if (kernel_sizes_tensor == nullptr) {
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ }
+ kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
+ auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
+ std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), kernel_sizes.begin());
+
+ const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
+ if (strides_tensor == nullptr) {
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ }
+ strides.resize(strides_tensor->shape().num_elements());
+ auto strides_vec = strides_tensor->flat<int32>();
+ std::copy_n(&strides_vec(0), strides.size(), strides.begin());
+ }
+
+ if (strides.size() != 4) {
+ return errors::InvalidArgument(
+ "MaxPool requires the stride attribute to contain 4 values, but "
+ "got: ",
+ strides.size());
+ }
+ if (kernel_sizes.size() != 4) {
+ return errors::InvalidArgument(
+ "MaxPool requires the ksize attribute to contain 4 values, but got: ",
+ kernel_sizes.size());
+ }
+
+ int32 stride_rows, stride_cols, stride_depth;
+ int32 kernel_rows, kernel_cols, kernel_depth;
+
+ if (s.ok() && data_format == "NCHW") {
+ // Canonicalize input shape to NHWC so the shape inference code below can
+ // process it.
+ auto dim = [&](char dimension) {
+ return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension));
+ };
+ input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}});
+ stride_depth = strides[1];
+ stride_rows = strides[2];
+ stride_cols = strides[3];
+ kernel_depth = kernel_sizes[1];
+ kernel_rows = kernel_sizes[2];
+ kernel_cols = kernel_sizes[3];
+ } else {
+ stride_rows = strides[1];
+ stride_cols = strides[2];
+ stride_depth = strides[3];
+ kernel_rows = kernel_sizes[1];
+ kernel_cols = kernel_sizes[2];
+ kernel_depth = kernel_sizes[3];
+ }
+
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_depth_dim = c->Dim(input_shape, 3);
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ ShapeHandle output_shape;
+ DimensionHandle output_rows, output_cols, output_depth;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+ c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+ c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+ c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
+
+ output_shape =
+ c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
+ if (data_format == "NCHW") {
+ // Convert output shape back to expected NCHW data format.
+ auto dim = [&](char dimension) {
+ return c->Dim(output_shape, GetTensorDimIndex<2>(FORMAT_NHWC, dimension));
+ };
+ output_shape = c->MakeShape({{dim('N'), dim('C'), dim('0'), dim('1')}});
+ }
+
+ c->set_output(0, output_shape);
+ return Status::OK();
+}
+
Status Pool3DShape(shape_inference::InferenceContext* c) {
ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
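When the ksize and strides inputs are constants and every dimension is known, the GetWindowedOutputSizeFromDims calls in MaxPoolV2Shape above reduce to the usual windowed-output-size arithmetic. A minimal Python sketch under that assumption (in the real shape function, unknown dimensions simply propagate as unknown):

def windowed_output_size(input_size, kernel_size, stride, padding):
    # Same VALID/SAME formulas TensorFlow uses for one spatial dimension.
    if padding == "VALID":
        return (input_size - kernel_size + stride) // stride
    if padding == "SAME":
        return (input_size + stride - 1) // stride
    raise ValueError("unsupported padding: %r" % padding)

# Example: a 5-wide dimension, 2-wide window, stride 2
#   VALID -> 2, SAME -> 3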
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 73b915652f..fb79df07a4 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -179,6 +179,9 @@ Status AvgPoolShape(shape_inference::InferenceContext* c);
// Shape function for MaxPool-like operations.
Status MaxPoolShape(shape_inference::InferenceContext* c);
+// Shape function for MaxPoolV2-like operations.
+Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
+
// Shape function for 3D Pooling operations.
Status Pool3DShape(shape_inference::InferenceContext* c);
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index 6cb56797bf..8d825c13d7 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -208,22 +208,26 @@ class MaxPoolingGradOp : public OpKernel {
errors::InvalidArgument("Default MaxPoolingGradOp only supports NHWC ",
"on device type ",
DeviceTypeString(context->device_type())));
- OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
- OP_REQUIRES(context, ksize_.size() == 4,
- errors::InvalidArgument("Sliding window ksize field must "
- "specify 4 dimensions"));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
- OP_REQUIRES(context, stride_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+
+ if (context->num_inputs() == 3) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(
+ context, ksize_[3] == 1 && stride_[3] == 1,
+ errors::Unimplemented(
+ "MaxPoolingGrad is not yet supported on the depth dimension."));
+ }
+
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
- errors::Unimplemented(
- "Pooling is not yet supported on the batch dimension."));
- OP_REQUIRES(
- context, ksize_[3] == 1 && stride_[3] == 1,
- errors::Unimplemented(
- "MaxPoolingGrad is not yet supported on the depth dimension."));
}
void Compute(OpKernelContext* context) override {
@@ -250,8 +254,35 @@ class MaxPoolingGradOp : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(),
tensor_out.shape(),
&tensor_out_arg_max));
+ std::vector<int32> ksize = ksize_;
+ std::vector<int32> stride = stride_;
+ if (context->num_inputs() == 5) {
+ const Tensor& tensor_ksize = context->input(3);
+ auto value_ksize = tensor_ksize.flat<int32>();
+ ksize.resize(tensor_ksize.shape().num_elements());
+ std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+ const Tensor& tensor_stride = context->input(4);
+ auto value_stride = tensor_stride.flat<int32>();
+ stride.resize(tensor_stride.shape().num_elements());
+ std::copy_n(&value_stride(0), stride.size(), stride.begin());
+ }
- PoolParameters params{context, ksize_, stride_,
+ OP_REQUIRES(context, ksize.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, stride.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(
+ context, ksize[3] == 1 && stride[3] == 1,
+ errors::Unimplemented(
+ "MaxPoolingGrad is not yet supported on the depth dimension."));
+
+ PoolParameters params{context, ksize, stride,
padding_, FORMAT_NHWC, tensor_in.shape()};
if (!context->status().ok()) {
return;
@@ -309,20 +340,22 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
- OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
- OP_REQUIRES(context, ksize_.size() == 4,
- errors::InvalidArgument("Sliding window ksize field must "
- "specify 4 dimensions"));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
- OP_REQUIRES(context, stride_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ if (context->num_inputs() == 3) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
+ const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
+ OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
- const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
- OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
- errors::Unimplemented(
- "Pooling is not yet supported on the batch dimension."));
use_dnn_ = CanUseCudnn();
}
@@ -343,15 +376,40 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
TensorShape output_shape = tensor_in.shape();
+ std::vector<int32> ksize = ksize_;
+ std::vector<int32> stride = stride_;
+ if (context->num_inputs() == 5) {
+ const Tensor& tensor_ksize = context->input(3);
+ auto value_ksize = tensor_ksize.flat<int32>();
+ ksize.resize(tensor_ksize.shape().num_elements());
+ std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+ const Tensor& tensor_stride = context->input(4);
+ auto value_stride = tensor_stride.flat<int32>();
+ stride.resize(tensor_stride.shape().num_elements());
+ std::copy_n(&value_stride(0), stride.size(), stride.begin());
+ }
+ OP_REQUIRES(context, ksize.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, stride.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
+ const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
+ OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+
if (use_dnn_) {
DnnPoolingGradOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
- stride_, padding_, data_format_, &tensor_in, &tensor_out,
- out_backprop, output_shape);
+ context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize,
+ stride, padding_, data_format_, &tensor_in, &tensor_out, out_backprop,
+ output_shape);
} else {
CHECK(data_format_ == FORMAT_NHWC)
<< "Non-Cudnn MaxPoolGrad only supports NHWC format";
- MaxPoolingBackwardCustomKernel<T>(context, ksize_, stride_, padding_,
+ MaxPoolingBackwardCustomKernel<T>(context, ksize, stride, padding_,
&tensor_in, out_backprop, output_shape);
}
}
@@ -386,22 +444,25 @@ class MaxPoolingGradGradOp : public OpKernel {
errors::InvalidArgument(
"Default MaxPoolingGradGradOp only supports NHWC ",
"on device type ", DeviceTypeString(context->device_type())));
- OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
- OP_REQUIRES(context, ksize_.size() == 4,
- errors::InvalidArgument("Sliding window ksize field must "
- "specify 4 dimensions"));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
- OP_REQUIRES(context, stride_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
- errors::Unimplemented(
- "Pooling is not yet supported on the batch dimension."));
- OP_REQUIRES(
- context, ksize_[3] == 1 && stride_[3] == 1,
- errors::Unimplemented(
- "MaxPoolingGradGrad is not yet supported on the depth dimension."));
+
+ if (context->num_inputs() == 3) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(context, ksize_[3] == 1 && stride_[3] == 1,
+ errors::Unimplemented("MaxPoolingGradGrad is not yet "
+ "supported on the depth dimension."));
+ }
}
void Compute(OpKernelContext* context) override {
@@ -419,7 +480,35 @@ class MaxPoolingGradGradOp : public OpKernel {
context, out_grad_backprop.dims() == 4,
errors::InvalidArgument("out_grad_backprop must be 4-dimensional"));
- PoolParameters params{context, ksize_, stride_,
+ std::vector<int32> ksize = ksize_;
+ std::vector<int32> stride = stride_;
+ if (context->num_inputs() == 5) {
+ const Tensor& tensor_ksize = context->input(3);
+ auto value_ksize = tensor_ksize.flat<int32>();
+ ksize.resize(tensor_ksize.shape().num_elements());
+ std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+ const Tensor& tensor_stride = context->input(4);
+ auto value_stride = tensor_stride.flat<int32>();
+ stride.resize(tensor_stride.shape().num_elements());
+ std::copy_n(&value_stride(0), stride.size(), stride.begin());
+ }
+
+ OP_REQUIRES(context, ksize.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, stride.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(
+ context, ksize[3] == 1 && stride[3] == 1,
+ errors::Unimplemented(
+ "MaxPoolingGrad is not yet supported on the depth dimension."));
+
+ PoolParameters params{context, ksize, stride,
padding_, FORMAT_NHWC, tensor_in.shape()};
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
@@ -474,7 +563,7 @@ class MaxPoolingGradGradOp : public OpKernel {
// tensor_out_as_matrix with the corresponding values in
// top_diff_as_matrix.
auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
- int64 start, int64 limit) {
+ int64 start, int64 limit) {
const int32 depth = params.depth;
const int32 in_rows = params.tensor_in_rows;
const int32 in_cols = params.tensor_in_cols;
@@ -555,20 +644,22 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
- OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
- OP_REQUIRES(context, ksize_.size() == 4,
- errors::InvalidArgument("Sliding window ksize field must "
- "specify 4 dimensions"));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
- OP_REQUIRES(context, stride_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ if (context->num_inputs() == 3) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
+ const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
+ OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
- const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
- OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
- errors::Unimplemented(
- "Pooling is not yet supported on the batch dimension."));
}
void Compute(OpKernelContext* context) override {
@@ -590,7 +681,33 @@ class MaxPoolingGradGradOp<Eigen::GpuDevice, T> : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, tensor_out.shape(), &output));
- PoolParameters params{context, ksize_, stride_,
+ std::vector<int32> ksize = ksize_;
+ std::vector<int32> stride = stride_;
+ if (context->num_inputs() == 5) {
+ const Tensor& tensor_ksize = context->input(3);
+ auto value_ksize = tensor_ksize.flat<int32>();
+ ksize.resize(tensor_ksize.shape().num_elements());
+ std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+ const Tensor& tensor_stride = context->input(4);
+ auto value_stride = tensor_stride.flat<int32>();
+ stride.resize(tensor_stride.shape().num_elements());
+ std::copy_n(&value_stride(0), stride.size(), stride.begin());
+ }
+
+ OP_REQUIRES(context, ksize.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, stride.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
+ const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
+ OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+
+ PoolParameters params{context, ksize, stride,
padding_, data_format_, tensor_in.shape()};
functor::MaxPoolGradBackwardNoMask<T>()(
@@ -670,6 +787,84 @@ class MaxPoolingNoMaskOp : public OpKernel {
};
template <typename Device, typename T>
+class MaxPoolingNoMaskV2Op : public OpKernel {
+ public:
+ explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Default MaxPoolingNoMaskOp only supports NHWC on device type ",
+ DeviceTypeString(context->device_type())));
+ if (context->num_inputs() == 1) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+
+ std::vector<int32> ksize = ksize_;
+ std::vector<int32> stride = stride_;
+
+ if (context->num_inputs() != 1) {
+ const Tensor& tensor_ksize = context->input(1);
+ auto value_ksize = tensor_ksize.flat<int32>();
+ ksize.resize(tensor_ksize.shape().num_elements());
+ std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+ const Tensor& tensor_stride = context->input(2);
+ auto value_stride = tensor_stride.flat<int32>();
+ stride.resize(tensor_stride.shape().num_elements());
+ std::copy_n(&value_stride(0), stride.size(), stride.begin());
+ }
+ OP_REQUIRES(context, ksize.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, stride.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ PoolParameters params{context, ksize, stride,
+ padding_, data_format_, tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ TensorShape out_shape({params.tensor_in_batch, params.out_height,
+ params.out_width, params.depth});
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
+ output);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+};
+
+template <typename Device, typename T>
struct LaunchMaxPoolingWithArgmax;
template <typename Device, typename T>
@@ -879,6 +1074,95 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
};
template <typename T>
+class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
+ public:
+ typedef GPUDevice Device;
+ explicit MaxPoolingNoMaskV2Op(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ if (context->num_inputs() == 1) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ const int32 ksize_n = GetTensorDim(ksize_, data_format_, 'N');
+ const int32 stride_n = GetTensorDim(stride_, data_format_, 'N');
+ OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ use_dnn_ = CanUseCudnn();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+
+ std::vector<int32> ksize = ksize_;
+ std::vector<int32> stride = stride_;
+
+ if (context->num_inputs() != 1) {
+ const Tensor& tensor_ksize = context->input(1);
+ auto value_ksize = tensor_ksize.flat<int32>();
+ ksize.resize(tensor_ksize.shape().num_elements());
+ std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+ const Tensor& tensor_stride = context->input(2);
+ auto value_stride = tensor_stride.flat<int32>();
+ stride.resize(tensor_stride.shape().num_elements());
+ std::copy_n(&value_stride(0), stride.size(), stride.begin());
+ }
+ OP_REQUIRES(context, ksize.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, stride.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ const int32 ksize_n = GetTensorDim(ksize, data_format_, 'N');
+ const int32 stride_n = GetTensorDim(stride, data_format_, 'N');
+ OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+
+ PoolParameters params{context, ksize, stride,
+ padding_, data_format_, tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ TensorShape out_shape =
+ ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
+ params.out_width, params.depth);
+ if (use_dnn_ && data_format_ == FORMAT_NCHW) {
+ DnnPoolingOp<T>::Compute(
+ context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize,
+ stride, padding_, data_format_, tensor_in, out_shape);
+ } else {
+ CHECK(data_format_ == FORMAT_NHWC)
+ << "Non-Cudnn MaxPool only supports NHWC format";
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+ LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
+ output);
+ }
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool use_dnn_;
+};
+
+template <typename T>
struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
static void launch(OpKernelContext* context, const PoolParameters& params,
const Tensor& input, Tensor* output) {
@@ -969,13 +1253,28 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
MaxPoolingGradOp<D##Device, T>); \
REGISTER_KERNEL_BUILDER( \
Name("MaxPoolGradGrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
- MaxPoolingGradGradOp<D##Device, T>);
+ MaxPoolingGradGradOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolGradV2") \
+ .Device(DEVICE_##D) \
+ .HostMemory("ksize") \
+ .HostMemory("strides") \
+ .TypeConstraint<T>("T"), \
+ MaxPoolingGradOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradV2") \
+ .Device(DEVICE_##D) \
+ .HostMemory("ksize") \
+ .HostMemory("strides") \
+ .TypeConstraint<T>("T"), \
+ MaxPoolingGradGradOp<D##Device, T>);
// Below kernels implemented only for CPU device.
-#define REGISTER_CPU_ONLY_POOL_KERNELS(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- MaxPoolingOp<CPUDevice, T>);
+#define REGISTER_CPU_ONLY_POOL_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ MaxPoolingOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MaxPoolV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ MaxPoolingV2Op<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_ONLY_POOL_KERNELS);
#undef REGISTER_CPU_ONLY_POOL_KERNELS
@@ -1015,9 +1314,22 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
.TypeConstraint<T>("T") \
.Label("eigen_tensor"), \
MaxPoolingOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("ksize") \
+ .HostMemory("strides") \
+ .TypeConstraint<T>("T") \
+ .Label("eigen_tensor"), \
+ MaxPoolingV2Op<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
MaxPoolingNoMaskOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("ksize") \
+ .HostMemory("strides") \
+ .TypeConstraint<T>("T"), \
+ MaxPoolingNoMaskV2Op<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") \
.Device(DEVICE_GPU) \
.TypeConstraint<int64>("Targmax") \
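The registrations above wire the V2 ops to the new kernel classes, with ksize and strides pinned to host memory so their values can be read at runtime. A small usage sketch, assuming the generated Python wrappers that the test changes below exercise; the placeholder-fed call mirrors _VerifyOneType:

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

# One MaxPoolV2 node can be evaluated with different windows at run time,
# because ksize/strides are int32 inputs rather than attrs.
x = tf.constant(np.arange(1, 17, dtype=np.float32).reshape(1, 4, 4, 1))
ksize = tf.placeholder(tf.int32, shape=[4])
strides = tf.placeholder(tf.int32, shape=[4])
y = gen_nn_ops._max_pool_v2(x, ksize=ksize, strides=strides, padding="VALID")
with tf.Session() as sess:
    print(sess.run(y, {ksize: [1, 2, 2, 1], strides: [1, 2, 2, 1]}))  # [[6, 8], [14, 16]]
    print(sess.run(y, {ksize: [1, 3, 3, 1], strides: [1, 1, 1, 1]}))  # [[11, 12], [15, 16]]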
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
index 2c097c0ce2..1b59c18df7 100644
--- a/tensorflow/core/kernels/pooling_ops_common.h
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -69,6 +69,8 @@ struct PoolParameters {
};
// An implementation of MaxPooling (forward).
+// TODO (yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op,
+// QuantizedMaxPoolingOp depends on MaxPoolingOp so keep intact for now
template <typename Device, typename T>
class MaxPoolingOp : public OpKernel {
public:
@@ -255,6 +257,219 @@ class MaxPoolingOp : public OpKernel {
};
template <typename Device, typename T>
+class MaxPoolingV2Op : public OpKernel {
+ public:
+ explicit MaxPoolingV2Op(OpKernelConstruction* context) : OpKernel(context) {
+ string data_format;
+ auto status = context->GetAttr("data_format", &data_format);
+ if (status.ok()) {
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument("Default MaxPoolingOp only supports NHWC."));
+ } else {
+ data_format_ = FORMAT_NHWC;
+ }
+ if (context->num_inputs() == 1) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+
+ std::vector<int32> ksize = ksize_;
+ std::vector<int32> stride = stride_;
+
+ if (context->num_inputs() != 1) {
+ const Tensor& tensor_ksize = context->input(1);
+ auto value_ksize = tensor_ksize.flat<int32>();
+ ksize.resize(tensor_ksize.shape().num_elements());
+ std::copy_n(&value_ksize(0), ksize.size(), ksize.begin());
+
+ const Tensor& tensor_stride = context->input(2);
+ auto value_stride = tensor_stride.flat<int32>();
+ stride.resize(tensor_stride.shape().num_elements());
+ std::copy_n(&value_stride(0), stride.size(), stride.begin());
+ }
+
+ OP_REQUIRES(context, ksize.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, stride.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+
+ PoolParameters params{context, ksize, stride,
+ padding_, FORMAT_NHWC, tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, params.forward_output_shape(), &output));
+
+ if (params.depth_window > 1) {
+ // Validate spec against the current implementation. A
+ // relaxation of these requirements would be ideal.
+ OP_REQUIRES(context, params.depth % params.depth_window == 0,
+ errors::Unimplemented(
+ "Depthwise max pooling requires "
+ "the depth window to evenly divide the input depth."));
+ OP_REQUIRES(
+ context, params.depth_window == params.depth_stride,
+ errors::Unimplemented("Depthwise max pooling requires "
+ "the depth window to equal the depth stride."));
+
+ DepthwiseMaxPool(context, output, tensor_in, params);
+ } else {
+ SpatialMaxPool(context, output, tensor_in, params, padding_);
+ }
+ }
+
+ private:
+ // Single-threaded implementation of DepthwiseMaxPool which
+ // does not handle all of the same options as SpatialMaxPool
+ // (strict assumptions on no padding, stride).
+ //
+ // TODO(vrv): implement a more general depthwise-max pool that works
+ // on GPU as well.
+ void DepthwiseMaxPool(OpKernelContext* context, Tensor* output,
+ const Tensor& tensor_in, const PoolParameters& params) {
+ Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ in_by_pool(tensor_in.flat<T>().data(), params.depth_window,
+ tensor_in.NumElements() / params.depth_window);
+ Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> out_by_pool(
+ output->flat<T>().data(), 1, output->NumElements());
+ out_by_pool = in_by_pool.colwise().maxCoeff();
+ }
+
+ void SpatialMaxPool(OpKernelContext* context, Tensor* output,
+ const Tensor& tensor_in, const PoolParameters& params,
+ const Padding& padding) {
+ // On GPU, use Eigen's Spatial Max Pooling. On CPU, use an
+ // EigenMatrix version that is currently faster than Eigen's
+ // Spatial MaxPooling implementation.
+ //
+ // TODO(vrv): Remove this once we no longer need it.
+ if (std::is_same<Device, GPUDevice>::value) {
+ Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
+ functor::SpatialMaxPooling<Device, T>()(
+ context->eigen_device<Device>(), output->tensor<T, 4>(),
+ tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
+ params.row_stride, params.col_stride, pt);
+ } else {
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ ConstEigenMatrixMap;
+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ EigenMatrixMap;
+
+ ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth,
+ params.tensor_in_cols * params.tensor_in_rows *
+ params.tensor_in_batch);
+ EigenMatrixMap out_mat(
+ output->flat<T>().data(), params.depth,
+ params.out_width * params.out_height * params.tensor_in_batch);
+
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *(context->device()->tensorflow_cpu_worker_threads());
+
+ // The following code basically does the following:
+ // 1. Flattens the input and output tensors into two dimensional arrays.
+ // tensor_in_as_matrix:
+ // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
+ // output_as_matrix:
+ // depth by (out_width * out_height * tensor_in_batch)
+ //
+ // 2. Walks through the set of columns in the flattened
+ // tensor_in_as_matrix,
+ // and updates the corresponding column(s) in output_as_matrix with the
+ // max value.
+ auto shard = [&params, &in_mat, &out_mat](int64 start, int64 limit) {
+
+ const int32 in_rows = params.tensor_in_rows;
+ const int32 in_cols = params.tensor_in_cols;
+ const int32 pad_rows = params.pad_rows;
+ const int32 pad_cols = params.pad_cols;
+ const int32 window_rows = params.window_rows;
+ const int32 window_cols = params.window_cols;
+ const int32 row_stride = params.row_stride;
+ const int32 col_stride = params.col_stride;
+ const int32 out_height = params.out_height;
+ const int32 out_width = params.out_width;
+
+ {
+ // Initializes the output tensor with MIN<T>.
+ const int32 output_image_size = out_height * out_width * params.depth;
+ EigenMatrixMap out_shard(out_mat.data() + start * output_image_size,
+ 1, (limit - start) * output_image_size);
+ out_shard.setConstant(Eigen::NumTraits<T>::lowest());
+ }
+
+ for (int32 b = start; b < limit; ++b) {
+ const int32 out_offset_batch = b * out_height;
+ for (int32 h = 0; h < in_rows; ++h) {
+ for (int32 w = 0; w < in_cols; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ const int32 hpad = h + pad_rows;
+ const int32 wpad = w + pad_cols;
+ const int32 h_start = (hpad < window_rows)
+ ? 0
+ : (hpad - window_rows) / row_stride + 1;
+ const int32 h_end = std::min(hpad / row_stride + 1, out_height);
+ const int32 w_start = (wpad < window_cols)
+ ? 0
+ : (wpad - window_cols) / col_stride + 1;
+ const int32 w_end = std::min(wpad / col_stride + 1, out_width);
+ // compute elementwise max
+ const int32 in_offset = (b * in_rows + h) * in_cols + w;
+ for (int32 ph = h_start; ph < h_end; ++ph) {
+ const int32 out_offset_base =
+ (out_offset_batch + ph) * out_width;
+ for (int32 pw = w_start; pw < w_end; ++pw) {
+ const int32 out_offset = out_offset_base + pw;
+ out_mat.col(out_offset) =
+ out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
+ }
+ }
+ }
+ }
+ }
+ };
+
+ // TODO(andydavis) Consider sharding across batch x rows x cols.
+ // TODO(andydavis) Consider a higher resolution shard cost model.
+ const int64 shard_cost =
+ params.tensor_in_rows * params.tensor_in_cols * params.depth;
+ Shard(worker_threads.num_threads, worker_threads.workers,
+ params.tensor_in_batch, shard_cost, shard);
+ }
+ }
+
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+};
+
+template <typename Device, typename T>
void SpatialAvgPool(OpKernelContext* context, Tensor* output,
const Tensor& input, const PoolParameters& params,
const Padding& padding) {
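The depthwise branch of MaxPoolingV2Op above requires the depth window to divide the input depth and to equal the depth stride, so each output channel is simply the max over a disjoint group of input channels. A NumPy sketch of that semantics, assuming the input is filled with 1..24 as the depthwise test further down suggests by its expected output:

import numpy as np

x = np.arange(1, 25, dtype=np.float32).reshape(1, 2, 2, 6)  # NHWC
depth_window = 3  # must divide the depth (6) and equal the depth stride
out = x.reshape(1, 2, 2, 6 // depth_window, depth_window).max(axis=-1)
print(out.flatten())  # [ 3.  6.  9. 12. 15. 18. 21. 24.]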
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 1018742521..0a96258dd1 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1368,6 +1368,34 @@ input: 4-D input to pool over.
output: The max pooled output tensor.
)doc");
+REGISTER_OP("MaxPoolV2")
+ .Attr("T: realnumbertype = DT_FLOAT")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Input("input: T")
+ .Input("ksize: int32")
+ .Input("strides: int32")
+ .Output("output: T")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Performs max pooling on the input.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+input: 4-D input to pool over.
+output: The max pooled output tensor.
+)doc");
+
REGISTER_OP("MaxPoolGrad")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
@@ -1399,6 +1427,37 @@ grad: 4-D. Gradients w.r.t. the output of `max_pool`.
output: Gradients w.r.t. the input to `max_pool`.
)doc");
+REGISTER_OP("MaxPoolGradV2")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Input("orig_input: T")
+ .Input("orig_output: T")
+ .Input("grad: T")
+ .Input("ksize: int32")
+ .Input("strides: int32")
+ .Output("output: T")
+ .Attr("T: realnumbertype = DT_FLOAT")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 4);
+ })
+ .Doc(R"doc(
+Computes gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: 4-D. Gradients w.r.t. the output of `max_pool`.
+output: Gradients w.r.t. the input to `max_pool`.
+)doc");
+
REGISTER_OP("MaxPoolGradGrad")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
@@ -1436,6 +1495,43 @@ grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
output: Gradients of gradients w.r.t. the input to `max_pool`.
)doc");
+REGISTER_OP("MaxPoolGradGradV2")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Input("orig_input: T")
+ .Input("orig_output: T")
+ .Input("grad: T")
+ .Input("ksize: int32")
+ .Input("strides: int32")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5));
+ ShapeHandle unused;
+ // Validate 'orig_input' is the same shape as 'grad'
+ TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
+ // Validate 'orig_output' is same shape as 'output'
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Computes second-order gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
+output: Gradients of gradients w.r.t. the input to `max_pool`.
+)doc");
+
REGISTER_OP("MaxPoolWithArgmax")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index f5fb7e4e03..da14871c87 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
+from tensorflow.python.framework import ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@@ -76,7 +77,7 @@ def GetShrunkInceptionMaxPoolShapes(shrink=30):
class PoolingTest(test.TestCase):
def _VerifyOneType(self, pool_func, input_sizes, ksize, strides, padding,
- data_format, data_type, expected, use_gpu):
+ data_format, data_type, expected, use_gpu, v2):
"""Verifies the output values of the pooling function.
Args:
@@ -103,20 +104,35 @@ class PoolingTest(test.TestCase):
t = test_util.NHWCToNCHW(t)
ksize = test_util.NHWCToNCHW(ksize)
strides = test_util.NHWCToNCHW(strides)
- t = pool_func(
- t,
- ksize=ksize,
- strides=strides,
- padding=padding,
- data_format=data_format)
+ v2 = v2 and data_format != "NCHW"
+ ksize_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
+ strides_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
+ if v2:
+ t = pool_func(
+ t,
+ ksize=ksize_placeholder,
+ strides=strides_placeholder,
+ padding=padding,
+ data_format=data_format)
+ else:
+ t = pool_func(
+ t,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format=data_format)
if data_format == "NCHW":
t = test_util.NCHWToNHWC(t)
- actual = t.eval()
+ if v2:
+ actual = t.eval(feed_dict={ksize_placeholder: ksize,
+ strides_placeholder: strides})
+ else:
+ actual = t.eval()
+ self.assertShapeEqual(actual, t)
self.assertAllCloseAccordingToType(expected, actual.flatten())
- self.assertShapeEqual(actual, t)
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
- data_format, expected, use_gpu):
+ data_format, expected, use_gpu, v2):
"""Verifies the output values of the pooling function.
Args:
@@ -131,14 +147,14 @@ class PoolingTest(test.TestCase):
use_gpu: Whether we are running on GPU.
"""
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
- data_format, dtypes.float32, expected, use_gpu)
+ data_format, dtypes.float32, expected, use_gpu, v2)
if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv():
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
- data_format, dtypes.float16, expected, use_gpu)
+ data_format, dtypes.float16, expected, use_gpu, v2)
def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
- expected, use_gpu):
+ expected, use_gpu, v2=False):
"""Verifies the output values of the pooling function.
Args:
@@ -154,7 +170,7 @@ class PoolingTest(test.TestCase):
for (data_format, use_gpu_2) in GetTestConfigs():
if use_gpu_2 == use_gpu:
self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
- data_format, expected, use_gpu)
+ data_format, expected, use_gpu, v2)
def _testAvgPoolValidPadding(self, use_gpu):
expected_output = [7.0, 8.0, 9.0]
@@ -325,6 +341,17 @@ class PoolingTest(test.TestCase):
expected=expected_output,
use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 3, 3, 3],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="VALID",
+ expected=expected_output,
+ use_gpu=use_gpu,
+ v2=v2)
+
def _testMaxPoolSamePadding(self, use_gpu):
expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
self._VerifyValues(
@@ -336,6 +363,17 @@ class PoolingTest(test.TestCase):
expected=expected_output,
use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 2, 3, 3],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output,
+ use_gpu=use_gpu,
+ v2=v2)
+
def _testMaxPoolSamePaddingNonSquareWindow(self, use_gpu):
# input is:
# [1.0, 2.0
@@ -354,6 +392,17 @@ class PoolingTest(test.TestCase):
expected=[2.0, 2.0, 4.0, 4.0],
use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 2, 2, 1],
+ ksize=[1, 1, 2, 1],
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ expected=[2.0, 2.0, 4.0, 4.0],
+ use_gpu=use_gpu,
+ v2=v2)
+
def _testMaxPoolValidPaddingUnevenStride(self, use_gpu):
self._VerifyValues(
nn_ops.max_pool,
@@ -372,6 +421,26 @@ class PoolingTest(test.TestCase):
expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0],
use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 1, 2, 1],
+ padding="VALID",
+ expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0],
+ use_gpu=use_gpu,
+ v2=v2)
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 1, 1],
+ padding="VALID",
+ expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0],
+ use_gpu=use_gpu,
+ v2=v2)
+
def _testMaxPoolSamePaddingPacket4(self, use_gpu):
expected_output = [
21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0,
@@ -386,6 +455,17 @@ class PoolingTest(test.TestCase):
expected=expected_output,
use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 4, 4, 4],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output,
+ use_gpu=use_gpu,
+ v2=v2)
+
def _testMaxPoolSamePaddingPacket8(self, use_gpu):
expected_output = [
145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0,
@@ -411,6 +491,17 @@ class PoolingTest(test.TestCase):
expected=expected_output,
use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 8, 8, 8],
+ ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output,
+ use_gpu=use_gpu,
+ v2=v2)
+
def testMaxPooling(self):
for use_gpu in True, False:
self._testMaxPoolValidPadding(use_gpu)
@@ -435,6 +526,17 @@ class PoolingTest(test.TestCase):
expected=[2.0, 4.0, 6.0, 8.0, 10.0],
use_gpu=False)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 1, 1, 10],
+ ksize=[1, 1, 1, 2],
+ strides=[1, 1, 1, 2],
+ padding="SAME",
+ expected=[2.0, 4.0, 6.0, 8.0, 10.0],
+ use_gpu=False,
+ v2=v2)
+
def testDepthwiseMaxPool2x2DepthWindow3(self):
# input is:
#
@@ -450,6 +552,17 @@ class PoolingTest(test.TestCase):
expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0],
use_gpu=False)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 2, 2, 6],
+ ksize=[1, 1, 1, 3],
+ strides=[1, 1, 1, 3],
+ padding="SAME",
+ expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0],
+ use_gpu=False,
+ v2=v2)
+
def testKernelSmallerThanStrideValid(self):
for use_gpu in [True, False]:
self._VerifyValues(
@@ -461,6 +574,17 @@ class PoolingTest(test.TestCase):
expected=[9, 12, 30, 33],
use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 7, 7, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 3, 3, 1],
+ padding="VALID",
+ expected=[9, 12, 30, 33],
+ use_gpu=use_gpu,
+ v2=v2)
+
self._VerifyValues(
nn_ops.avg_pool,
input_sizes=[1, 7, 7, 1],
@@ -491,6 +615,27 @@ class PoolingTest(test.TestCase):
expected=[1, 3, 9, 11],
use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 3, 3, 1],
+ ksize=[1, 1, 1, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=[1, 3, 7, 9],
+ use_gpu=use_gpu,
+ v2=v2)
+
+ self._VerifyValues(
+ gen_nn_ops._max_pool_v2,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 1, 1, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=[1, 3, 9, 11],
+ use_gpu=use_gpu,
+ v2=v2)
+
def _testDepthwiseMaxPoolInvalidConfig(self,
in_size,
ksize,
@@ -812,99 +957,107 @@ class PoolingTest(test.TestCase):
self.assertLess(err, err_tolerance)
def _testMaxPoolGradValidPadding1_1(self, data_format, use_gpu):
- self._ConstructAndTestGradient(
- nn_ops.max_pool,
- input_sizes=[1, 3, 3, 1],
- output_sizes=[1, 3, 3, 1],
- window_rows=1,
- window_cols=1,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestGradient(
+ pool_func,
+ input_sizes=[1, 3, 3, 1],
+ output_sizes=[1, 3, 3, 1],
+ window_rows=1,
+ window_cols=1,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradValidPadding2_1_6(self, data_format, use_gpu):
- self._ConstructAndTestGradient(
- nn_ops.max_pool,
- input_sizes=[2, 6, 6, 3],
- output_sizes=[2, 5, 5, 3],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestGradient(
+ pool_func,
+ input_sizes=[2, 6, 6, 3],
+ output_sizes=[2, 5, 5, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradValidPadding2_1_7(self, data_format, use_gpu):
- self._ConstructAndTestGradient(
- nn_ops.max_pool,
- input_sizes=[2, 7, 7, 3],
- output_sizes=[2, 6, 6, 3],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestGradient(
+ pool_func,
+ input_sizes=[2, 7, 7, 3],
+ output_sizes=[2, 6, 6, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradValidPadding2_2(self, data_format, use_gpu):
- self._ConstructAndTestGradient(
- nn_ops.max_pool,
- input_sizes=[2, 2, 2, 3],
- output_sizes=[2, 1, 1, 3],
- window_rows=2,
- window_cols=2,
- row_stride=2,
- col_stride=2,
- padding="VALID",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestGradient(
+ pool_func,
+ input_sizes=[2, 2, 2, 3],
+ output_sizes=[2, 1, 1, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=2,
+ col_stride=2,
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradSamePadding1_1(self, data_format, use_gpu):
- self._ConstructAndTestGradient(
- nn_ops.max_pool,
- input_sizes=[2, 2, 4, 3],
- output_sizes=[2, 2, 4, 3],
- window_rows=1,
- window_cols=1,
- row_stride=1,
- col_stride=1,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestGradient(
+ pool_func,
+ input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 2, 4, 3],
+ window_rows=1,
+ window_cols=1,
+ row_stride=1,
+ col_stride=1,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradSamePadding2_1(self, data_format, use_gpu):
- self._ConstructAndTestGradient(
- nn_ops.max_pool,
- input_sizes=[2, 2, 4, 3],
- output_sizes=[2, 2, 4, 3],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestGradient(
+ pool_func,
+ input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 2, 4, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradSamePadding2_2(self, data_format, use_gpu):
- self._ConstructAndTestGradient(
- nn_ops.max_pool,
- input_sizes=[2, 2, 4, 3],
- output_sizes=[2, 1, 2, 3],
- window_rows=2,
- window_cols=2,
- row_stride=2,
- col_stride=2,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestGradient(
+ pool_func,
+ input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 1, 2, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=2,
+ col_stride=2,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradSamePadding3_1(self, data_format, use_gpu):
- self._ConstructAndTestGradient(
- nn_ops.max_pool,
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestGradient(
+ pool_func,
input_sizes=[1, 7, 7, 1],
output_sizes=[1, 7, 7, 1],
window_rows=3,
@@ -927,7 +1080,7 @@ class PoolingTest(test.TestCase):
self._testMaxPoolGradSamePadding3_1(data_format, use_gpu)
def _MaxPoolGrad(self, orig_input, orig_output, grad, window_rows,
- window_cols, row_stride, col_stride, padding):
+ window_cols, row_stride, col_stride, padding, v2):
"""Max Pooling Gradient.
Args:
@@ -944,26 +1097,29 @@ class PoolingTest(test.TestCase):
Returns:
A Tensor.
"""
- return gen_nn_ops._max_pool_grad(orig_input, orig_output, grad,
- [1, window_rows, window_cols, 1],
- [1, row_stride, col_stride, 1], padding)
+ pool_func = gen_nn_ops.max_pool_grad_v2 if v2 else gen_nn_ops._max_pool_grad
+ return pool_func(orig_input, orig_output, grad,
+ [1, window_rows, window_cols, 1],
+ [1, row_stride, col_stride, 1], padding)
def _testMaxPoolGradDirect(self, input_data, output_backprop,
expected_input_backprop, input_sizes, output_sizes,
window_rows, window_cols, row_stride, col_stride,
- padding, use_gpu):
+ padding, use_gpu, v2):
+ pool_func = gen_nn_ops._max_pool_v2 if v2 else nn_ops.max_pool
with self.test_session(use_gpu=use_gpu):
input_tensor = constant_op.constant(input_data, shape=input_sizes)
- output_tensor = nn_ops.max_pool(input_tensor,
- [1, window_rows, window_cols, 1],
- [1, row_stride, col_stride, 1], padding)
+ output_tensor = pool_func(input_tensor,
+ [1, window_rows, window_cols, 1],
+ [1, row_stride, col_stride, 1], padding)
output_backprop_tensor = constant_op.constant(
output_backprop, shape=output_sizes)
input_backprop_tensor = self._MaxPoolGrad(input_tensor, output_tensor,
output_backprop_tensor,
window_rows, window_cols,
- row_stride, col_stride, padding)
+ row_stride, col_stride,
+ padding, v2)
actual_input_backprop = input_backprop_tensor.eval()
self.assertShapeEqual(actual_input_backprop, input_backprop_tensor)
@@ -988,18 +1144,20 @@ class PoolingTest(test.TestCase):
]
for use_gpu in True, False:
- self._testMaxPoolGradDirect(
- input_data,
- output_backprop,
- expected_input_backprop,
- input_sizes=[1, 4, 4, 1],
- output_sizes=[1, 3, 3, 1],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._testMaxPoolGradDirect(
+ input_data,
+ output_backprop,
+ expected_input_backprop,
+ input_sizes=[1, 4, 4, 1],
+ output_sizes=[1, 3, 3, 1],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ use_gpu=use_gpu,
+ v2=v2)
def _testMaxPoolGradDirect1_2(self):
input_data = [
@@ -1013,18 +1171,20 @@ class PoolingTest(test.TestCase):
]
for use_gpu in True, False:
- self._testMaxPoolGradDirect(
- input_data,
- output_backprop,
- expected_input_backprop,
- input_sizes=[1, 4, 4, 1],
- output_sizes=[1, 3, 3, 1],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._testMaxPoolGradDirect(
+ input_data,
+ output_backprop,
+ expected_input_backprop,
+ input_sizes=[1, 4, 4, 1],
+ output_sizes=[1, 3, 3, 1],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ use_gpu=use_gpu,
+ v2=v2)
def _testMaxPoolGradDirect1_3(self):
input_data = [
@@ -1069,18 +1229,20 @@ class PoolingTest(test.TestCase):
]
for use_gpu in True, False:
- self._testMaxPoolGradDirect(
- input_data,
- output_backprop,
- expected_input_backprop,
- input_sizes=[1, 4, 4, 1],
- output_sizes=[1, 4, 4, 1],
- window_rows=3,
- window_cols=3,
- row_stride=1,
- col_stride=1,
- padding="SAME",
- use_gpu=use_gpu)
+ for v2 in [True, False]:
+ self._testMaxPoolGradDirect(
+ input_data,
+ output_backprop,
+ expected_input_backprop,
+ input_sizes=[1, 4, 4, 1],
+ output_sizes=[1, 4, 4, 1],
+ window_rows=3,
+ window_cols=3,
+ row_stride=1,
+ col_stride=1,
+ padding="SAME",
+ use_gpu=use_gpu,
+ v2=v2)
def _testMaxPoolGradDirectWithNans2_1(self):
input_data = [float("nan")] * 16
@@ -1090,18 +1252,20 @@ class PoolingTest(test.TestCase):
11.0, 12.0, 13.0, 0.0, 15.0, 16.0, 17.0, 0.0, 19.0, 20.0, 21.0, 0.0,
0.0, 0.0, 0.0, 0.0
]
- self._testMaxPoolGradDirect(
- input_data,
- output_backprop,
- expected_input_backprop_tf_cpu,
- input_sizes=[1, 4, 4, 1],
- output_sizes=[1, 3, 3, 1],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- use_gpu=False)
+ for v2 in [True, False]:
+ self._testMaxPoolGradDirect(
+ input_data,
+ output_backprop,
+ expected_input_backprop_tf_cpu,
+ input_sizes=[1, 4, 4, 1],
+ output_sizes=[1, 3, 3, 1],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ use_gpu=False,
+ v2=v2)
if not test.is_gpu_available():
return
@@ -1112,18 +1276,20 @@ class PoolingTest(test.TestCase):
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0
]
- self._testMaxPoolGradDirect(
- input_data,
- output_backprop,
- expected_input_backprop_cudnn,
- input_sizes=[1, 4, 4, 1],
- output_sizes=[1, 3, 3, 1],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- use_gpu=True)
+ for v2 in [True, False]:
+ self._testMaxPoolGradDirect(
+ input_data,
+ output_backprop,
+ expected_input_backprop_cudnn,
+ input_sizes=[1, 4, 4, 1],
+ output_sizes=[1, 3, 3, 1],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ use_gpu=True,
+ v2=v2)
def _testMaxPoolGradDirectWithNans2_2(self):
input_data = [float("nan")] * 16
@@ -1136,18 +1302,20 @@ class PoolingTest(test.TestCase):
float("nan"), 12.0, 13.0, 0.0, 15.0, float("nan"), 17.0, 0.0, 19.0,
20.0, float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0
]
- self._testMaxPoolGradDirect(
- input_data,
- output_backprop,
- expected_input_backprop_tf_cpu,
- input_sizes=[1, 4, 4, 1],
- output_sizes=[1, 3, 3, 1],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- use_gpu=False)
+ for v2 in [True, False]:
+ self._testMaxPoolGradDirect(
+ input_data,
+ output_backprop,
+ expected_input_backprop_tf_cpu,
+ input_sizes=[1, 4, 4, 1],
+ output_sizes=[1, 3, 3, 1],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ use_gpu=False,
+ v2=v2)
if not test.is_gpu_available():
return
@@ -1158,18 +1326,20 @@ class PoolingTest(test.TestCase):
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0
]
- self._testMaxPoolGradDirect(
- input_data,
- output_backprop,
- expected_input_backprop_cudnn,
- input_sizes=[1, 4, 4, 1],
- output_sizes=[1, 3, 3, 1],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- use_gpu=True)
+ for v2 in [True, False]:
+ self._testMaxPoolGradDirect(
+ input_data,
+ output_backprop,
+ expected_input_backprop_cudnn,
+ input_sizes=[1, 4, 4, 1],
+ output_sizes=[1, 3, 3, 1],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ use_gpu=True,
+ v2=v2)
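(A note on why the NaN tests above carry two expected arrays: the expected_input_backprop_tf_cpu values show the CPU kernel still routing backprop to a window position when the inputs are NaN, while expected_input_backprop_cudnn is all zeros, i.e. the cuDNN path selects no maximum for all-NaN inputs. A plausible explanation, stated here as an assumption rather than a fact from this patch, is that every ordered comparison with NaN is false, so a ">"-based max search never picks a winner. The comparison behaviour itself is easy to see in isolation with plain numpy.)

import numpy as np

# Any ordered comparison involving NaN is False, so a max search written as
# "if candidate > current_best" never promotes a NaN candidate, while one
# seeded with the first element keeps the NaN. Different implementations can
# therefore disagree about which position "won" the max.
print(np.nan > 1.0)   # False
print(1.0 > np.nan)   # False
print(np.nan <= 1.0)  # False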
def testMaxPoolGradDirect(self):
self._testMaxPoolGradDirect1_1()
@@ -1179,108 +1349,116 @@ class PoolingTest(test.TestCase):
self._testMaxPoolGradDirectWithNans2_2()
def _testMaxPoolGradGradValidPadding1_1(self, data_format, use_gpu):
- self._ConstructAndTestSecondGradient(
- nn_ops.max_pool,
- input_sizes=[1, 3, 3, 1],
- output_sizes=[1, 3, 3, 1],
- window_rows=1,
- window_cols=1,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestSecondGradient(
+ pool_func,
+ input_sizes=[1, 3, 3, 1],
+ output_sizes=[1, 3, 3, 1],
+ window_rows=1,
+ window_cols=1,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradGradValidPadding2_1_6(self, data_format, use_gpu):
- self._ConstructAndTestSecondGradient(
- nn_ops.max_pool,
- input_sizes=[2, 6, 6, 3],
- output_sizes=[2, 5, 5, 3],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestSecondGradient(
+ pool_func,
+ input_sizes=[2, 6, 6, 3],
+ output_sizes=[2, 5, 5, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradGradValidPadding2_1_7(self, data_format, use_gpu):
- self._ConstructAndTestSecondGradient(
- nn_ops.max_pool,
- input_sizes=[2, 7, 7, 3],
- output_sizes=[2, 6, 6, 3],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="VALID",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestSecondGradient(
+ pool_func,
+ input_sizes=[2, 7, 7, 3],
+ output_sizes=[2, 6, 6, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradGradValidPadding2_2(self, data_format, use_gpu):
- self._ConstructAndTestSecondGradient(
- nn_ops.max_pool,
- input_sizes=[2, 2, 2, 3],
- output_sizes=[2, 1, 1, 3],
- window_rows=2,
- window_cols=2,
- row_stride=2,
- col_stride=2,
- padding="VALID",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestSecondGradient(
+ pool_func,
+ input_sizes=[2, 2, 2, 3],
+ output_sizes=[2, 1, 1, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=2,
+ col_stride=2,
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradGradSamePadding1_1(self, data_format, use_gpu):
- self._ConstructAndTestSecondGradient(
- nn_ops.max_pool,
- input_sizes=[2, 2, 4, 3],
- output_sizes=[2, 2, 4, 3],
- window_rows=1,
- window_cols=1,
- row_stride=1,
- col_stride=1,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestSecondGradient(
+ pool_func,
+ input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 2, 4, 3],
+ window_rows=1,
+ window_cols=1,
+ row_stride=1,
+ col_stride=1,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradGradSamePadding2_1(self, data_format, use_gpu):
- self._ConstructAndTestSecondGradient(
- nn_ops.max_pool,
- input_sizes=[2, 2, 4, 3],
- output_sizes=[2, 2, 4, 3],
- window_rows=2,
- window_cols=2,
- row_stride=1,
- col_stride=1,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestSecondGradient(
+ pool_func,
+ input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 2, 4, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=1,
+ col_stride=1,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradGradSamePadding2_2(self, data_format, use_gpu):
- self._ConstructAndTestSecondGradient(
- nn_ops.max_pool,
- input_sizes=[2, 2, 4, 3],
- output_sizes=[2, 1, 2, 3],
- window_rows=2,
- window_cols=2,
- row_stride=2,
- col_stride=2,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestSecondGradient(
+ pool_func,
+ input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 1, 2, 3],
+ window_rows=2,
+ window_cols=2,
+ row_stride=2,
+ col_stride=2,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def _testMaxPoolGradGradSamePadding3_1(self, data_format, use_gpu):
- self._ConstructAndTestSecondGradient(
- nn_ops.max_pool,
- input_sizes=[1, 7, 7, 1],
- output_sizes=[1, 7, 7, 1],
- window_rows=3,
- window_cols=3,
- row_stride=1,
- col_stride=1,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+ self._ConstructAndTestSecondGradient(
+ pool_func,
+ input_sizes=[1, 7, 7, 1],
+ output_sizes=[1, 7, 7, 1],
+ window_rows=3,
+ window_cols=3,
+ row_stride=1,
+ col_stride=1,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def testMaxPoolGradGrad(self):
for (data_format, use_gpu) in GetTestConfigs():
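(Taken together, the test changes exercise the forward MaxPoolV2 op with the same call shape as nn_ops.max_pool, except that ksize and strides may be tensors. A minimal forward-pass sketch, assuming the gen_nn_ops bindings generated by this change.)

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

x = tf.constant(np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1))

# ksize/strides are ordinary tensor inputs here and are fed at run time.
ksize = tf.placeholder(tf.int32, shape=[4])
strides = tf.placeholder(tf.int32, shape=[4])
y = gen_nn_ops._max_pool_v2(x, ksize, strides, padding="SAME")

with tf.Session() as sess:
  out = sess.run(y, feed_dict={ksize: [1, 2, 2, 1], strides: [1, 2, 2, 1]})
  print(out.shape)  # (1, 2, 2, 1)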
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index ffa5dc4d62..eeaf418c8b 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -302,6 +302,7 @@ BiasAddV1
Relu6
AvgPool
MaxPool
+MaxPoolV2
Softmax
LogSoftmax
FractionalAvgPoolGrad
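(Listing MaxPoolV2 in hidden_ops.txt keeps the raw op out of the public tf.nn namespace, so the generated binding is private (gen_nn_ops._max_pool_v2) and callers reach it either directly, as the tests above do, or through a higher-level wrapper. A hypothetical wrapper sketch follows, only to illustrate the dispatch; max_pool_dynamic is a made-up name and is not part of this patch.)

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nn_ops

def max_pool_dynamic(value, ksize, strides, padding, data_format="NHWC",
                     name=None):
  # Hypothetical wrapper: convert ksize/strides to int32 tensors and forward
  # to the hidden MaxPoolV2 binding.
  with ops.name_scope(name, "MaxPoolDynamic", [value]) as scope:
    ksize = ops.convert_to_tensor(ksize, dtype=dtypes.int32, name="ksize")
    strides = ops.convert_to_tensor(strides, dtype=dtypes.int32,
                                    name="strides")
    return gen_nn_ops._max_pool_v2(value, ksize, strides, padding=padding,
                                   data_format=data_format, name=scope)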
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 094757dac9..de302a2271 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -541,6 +541,19 @@ def _MaxPoolGrad(op, grad):
data_format=op.get_attr("data_format"))
+@ops.RegisterGradient("MaxPoolV2")
+def _MaxPoolGradV2(op, grad):
+ ksize = op.inputs[1]
+ strides = op.inputs[2]
+ return gen_nn_ops.max_pool_grad_v2(op.inputs[0],
+ op.outputs[0],
+ grad,
+ ksize,
+ strides,
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format")), None, None
+
+
@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
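(The MaxPoolV2 gradient registered above returns three values because the op has three inputs: a real gradient for the data input, and None for the ksize and strides inputs, which are integer parameters with no gradient. A minimal check, assuming the bindings and gradient registration from this change.)

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

x = tf.constant(np.random.rand(1, 4, 4, 1), dtype=tf.float32)
ksize = tf.constant([1, 2, 2, 1])
strides = tf.constant([1, 2, 2, 1])
y = gen_nn_ops._max_pool_v2(x, ksize, strides, padding="VALID")

# Backprop is routed through a MaxPoolGradV2 op; ksize/strides get no
# gradient (the function returns None for them), so only x is differentiated.
dx = tf.gradients(y, x)[0]

with tf.Session() as sess:
  print(sess.run(dx).shape)  # (1, 4, 4, 1)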
@@ -567,6 +580,24 @@ def _MaxPoolGradGrad(op, grad):
data_format=op.get_attr("data_format")))
+@ops.RegisterGradient("MaxPoolGradV2")
+def _MaxPoolGradGradV2(op, grad):
+ ksize = op.inputs[3]
+ strides = op.inputs[4]
+ return (array_ops.zeros(
+ shape=array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype), array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ gen_nn_ops.max_pool_grad_grad_v2(
+ op.inputs[0],
+ op.inputs[1],
+ grad,
+ ksize,
+ strides,
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format")), None, None)
+
+
@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
return (array_ops.zeros(
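(For the second-order case, MaxPoolGradV2 has five inputs (orig_input, orig_output, grad, ksize, strides), so the gradient function registered above returns five values: zero gradients for orig_input and orig_output, the MaxPoolGradGradV2 result for the incoming gradient, and None for the two integer inputs. A minimal sketch, assuming the registrations from this change: differentiating MaxPoolGradV2's output with respect to its incoming gradient tensor is what exercises MaxPoolGradGradV2.)

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

x = tf.constant(np.random.rand(1, 4, 4, 1), dtype=tf.float32)
ksize = tf.constant([1, 2, 2, 1])
strides = tf.constant([1, 2, 2, 1])
y = gen_nn_ops._max_pool_v2(x, ksize, strides, padding="VALID")

g = tf.constant(np.random.rand(1, 2, 2, 1), dtype=tf.float32)  # shape of y
dx = tf.gradients(y, x, grad_ys=g)[0]  # MaxPoolGradV2(x, y, g, ...)
dg = tf.gradients(dx, g)[0]            # routed through MaxPoolGradGradV2

with tf.Session() as sess:
  print(sess.run(dg).shape)            # (1, 2, 2, 1)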