aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-23 13:45:34 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-23 13:50:19 -0700
commit: c0d88a5b6c81fc4651cdb9678c6bc9608139e256 (patch)
tree: 424840825fcfaa99b98759cfed8a09ac3d79f802
parent: 9a774e4d2d31443ea694938bec41237b4d6bcf02 (diff)
[TF:XLA] Use newly-added pooling library AvgPoolGrad in tf2xla
PiperOrigin-RevId: 209992284
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/pooling_ops.cc 144
1 files changed, 22 insertions, 122 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index d4d180aff8..f6f158a73b 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -199,59 +199,6 @@ class MaxPool3DOp : public MaxPoolOp {
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
-// Divide each element of an image by the count of elements that contributed to
-// that element during pooling.
-static xla::XlaOp AvgPoolDivideByCount(
- XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
- const TensorShape& input_shape, xla::Padding padding,
- const std::vector<int64>& ksize, const std::vector<int64>& stride,
- int num_spatial_dims, TensorFormat data_format) {
- if (padding == xla::Padding::kValid) {
- // In VALID padding, all windows have the same number of elements
- // contributing to each average. Divide by the window size everywhere to
- // get the average.
- int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1,
- [](int64 a, int64 b) { return a * b; });
-
- auto divisor =
- XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
- return xla::Div(output, divisor);
- } else {
- // For SAME padding, the padding shouldn't be included in the
- // counts. We use another ReduceWindow to find the right counts.
-
- // TODO(phawkins): use a less brute-force way to compute this. Only
- // the boundary regions will have interesting values here.
-
- std::vector<int64> input_dim_sizes(num_spatial_dims);
- std::vector<int64> window_dims(num_spatial_dims);
- std::vector<int64> window_ksize(num_spatial_dims);
- std::vector<int64> window_stride(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i);
- input_dim_sizes[i] = input_shape.dim_size(dim);
- window_dims[i] = dim;
- window_ksize[i] = ksize[dim];
- window_stride[i] = stride[dim];
- }
-
- // Build a matrix of all 1s, with the same width/height as the input.
- const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto ones = xla::Broadcast(
- XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes);
-
- // Perform a ReduceWindow with the same window size, strides, and padding
- // to count the number of contributions to each result element.
- auto reduce = xla::ReduceWindow(
- ones, XlaHelpers::Zero(ctx->builder(), accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride,
- xla::Padding::kSame);
- auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype);
-
- return xla::Div(output, counts, window_dims);
- }
-}
-
class AvgPoolOp : public PoolingOp {
public:
AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
@@ -463,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel {
errors::InvalidArgument("out_backprop must be ", num_dims(),
"-dimensional"));
- int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- int64 depth = out_backprop_shape.dim_size(depth_dim);
-
- // We can think of average-pooling as:
- // * a convolution with a kernel consisting entirely of 1s, where the
- // input feature and output feature are equal, and 0s everywhere else.
- // * followed by dividing by the counts.
- //
- // This then gives us an algorithm to build the gradient:
- // * divide out_backprop by the counts, followed by
- // * Conv2DBackpropInput specialized for that kernel, which simplifies to
- // a Pad and a ReduceWindow.
- //
- // For an explanation of backpropagation for convolution, see the comments
- // in third_party/tensorflow/core/kernels/conv_grad_ops.h
-
- // TF filter shape is [ H, W, ..., inC, outC ]
- std::vector<int64> filter_dims(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- filter_dims[i] = ksize_[dim];
- }
- filter_dims[num_dims() - 2] = depth;
- filter_dims[num_dims() - 1] = depth;
- TensorShape filter_shape(filter_dims);
-
- // Reuse the logic from Conv2DBackpropInput to compute padding.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(
- ctx, ConvBackpropComputeDimensions(
- type_string(), /*num_spatial_dims=*/num_spatial_dims_,
- gradients_shape, filter_shape, out_backprop_shape, stride_,
- padding_, data_format_, &dims));
-
- // The input gradients are computed by a convolution of the output gradients
- // and the filter, with some appropriate padding. See the comment at the top
- // of conv_grad_ops.h for details.
- xla::XlaBuilder* const b = ctx->builder();
auto out_backprop = ctx->Input(1);
- auto dtype = input_type(1);
+ std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
xla::Padding xla_padding =
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
-
- // Divide the out_backprop values by the counts for each spatial position.
- std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
- auto out_backprop_div = AvgPoolDivideByCount(
- ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_,
- stride_int64s, num_spatial_dims_, data_format_);
-
- // Pad the gradients in the spatial dimensions. We use the same padding
- // as Conv2DBackpropInput.
- xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- auto* padding = padding_config.mutable_dimensions(dim);
- padding->set_edge_padding_low(dims.spatial_dims[i].pad_before);
- padding->set_edge_padding_high(dims.spatial_dims[i].pad_after);
- padding->set_interior_padding(dims.spatial_dims[i].stride - 1);
- }
-
- auto zero = XlaHelpers::Zero(b, dtype);
- auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config);
-
- // in_backprop = padded_gradients <conv> ones
- std::vector<int64> ones(num_dims(), 1LL);
- auto accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto in_backprop = xla::ReduceWindow(
- XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type),
- XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), ksize_,
- /* window_strides=*/ones, xla::Padding::kValid);
- ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype));
+ xla::PrimitiveType xla_reduction_type;
+ auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
+ auto converted_out_backprop =
+ xla::ConvertElementType(out_backprop, xla_reduction_type);
+ auto xla_data_format =
+ XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
+ auto padding_values =
+ MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
+ xla_padding, xla_data_format);
+ auto in_backprop =
+ xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
+ ksize_, stride_int64s, padding_values, xla_data_format,
+ /*counts_include_padding=*/padding_ == VALID);
+ // Convert the pooling result back to the input type before returning it.
+ xla::PrimitiveType xla_out_backprop_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
+ &xla_out_backprop_type));
+ ctx->SetOutput(0,
+ xla::ConvertElementType(in_backprop, xla_out_backprop_type));
}
protected: