diff options
author | 2018-08-23 13:45:34 -0700 | |
---|---|---|
committer | 2018-08-23 13:50:19 -0700 | |
commit | c0d88a5b6c81fc4651cdb9678c6bc9608139e256 (patch) | |
tree | 424840825fcfaa99b98759cfed8a09ac3d79f802 | |
parent | 9a774e4d2d31443ea694938bec41237b4d6bcf02 (diff) |
[TF:XLA] Use newly-added pooling library AvgPoolGrad in tf2xla
PiperOrigin-RevId: 209992284
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/pooling_ops.cc | 144 |
1 files changed, 22 insertions, 122 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d4d180aff8..f6f158a73b 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -199,59 +199,6 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Divide each element of an image by the count of elements that contributed to -// that element during pooling. -static xla::XlaOp AvgPoolDivideByCount( - XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape, xla::Padding padding, - const std::vector<int64>& ksize, const std::vector<int64>& stride, - int num_spatial_dims, TensorFormat data_format) { - if (padding == xla::Padding::kValid) { - // In VALID padding, all windows have the same number of elements - // contributing to each average. Divide by the window size everywhere to - // get the average. - int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1, - [](int64 a, int64 b) { return a * b; }); - - auto divisor = - XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return xla::Div(output, divisor); - } else { - // For SAME padding, the padding shouldn't be included in the - // counts. We use another ReduceWindow to find the right counts. - - // TODO(phawkins): use a less brute-force way to compute this. Only - // the boundary regions will have interesting values here. - - std::vector<int64> input_dim_sizes(num_spatial_dims); - std::vector<int64> window_dims(num_spatial_dims); - std::vector<int64> window_ksize(num_spatial_dims); - std::vector<int64> window_stride(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i); - input_dim_sizes[i] = input_shape.dim_size(dim); - window_dims[i] = dim; - window_ksize[i] = ksize[dim]; - window_stride[i] = stride[dim]; - } - - // Build a matrix of all 1s, with the same width/height as the input. - const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = xla::Broadcast( - XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); - - // Perform a ReduceWindow with the same window size, strides, and padding - // to count the number of contributions to each result element. - auto reduce = xla::ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, - xla::Padding::kSame); - auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - - return xla::Div(output, counts, window_dims); - } -} - class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) @@ -463,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel { errors::InvalidArgument("out_backprop must be ", num_dims(), "-dimensional")); - int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - int64 depth = out_backprop_shape.dim_size(depth_dim); - - // We can think of average-pooling as: - // * a convolution with a kernel consisting entirely of 1s, where the - // input feature and output feature are equal, and 0s everywhere else. - // * followed by dividing by the counts. - // - // This then gives us an algorithm to build the gradient: - // * divide out_backprop by the counts, followed by - // * Conv2DBackpropInput specialized for that kernel, which simplifies to - // a Pad and a ReduceWindow. - // - // For an explanation of backpropagation for convolution, see the comments - // in third_party/tensorflow/core/kernels/conv_grad_ops.h - - // TF filter shape is [ H, W, ..., inC, outC ] - std::vector<int64> filter_dims(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - filter_dims[i] = ksize_[dim]; - } - filter_dims[num_dims() - 2] = depth; - filter_dims[num_dims() - 1] = depth; - TensorShape filter_shape(filter_dims); - - // Reuse the logic from Conv2DBackpropInput to compute padding. - ConvBackpropDimensions dims; - OP_REQUIRES_OK( - ctx, ConvBackpropComputeDimensions( - type_string(), /*num_spatial_dims=*/num_spatial_dims_, - gradients_shape, filter_shape, out_backprop_shape, stride_, - padding_, data_format_, &dims)); - - // The input gradients are computed by a convolution of the output gradients - // and the filter, with some appropriate padding. See the comment at the top - // of conv_grad_ops.h for details. - xla::XlaBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - auto dtype = input_type(1); + std::vector<int64> stride_int64s(stride_.begin(), stride_.end()); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; - - // Divide the out_backprop values by the counts for each spatial position. - std::vector<int64> stride_int64s(stride_.begin(), stride_.end()); - auto out_backprop_div = AvgPoolDivideByCount( - ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_, - stride_int64s, num_spatial_dims_, data_format_); - - // Pad the gradients in the spatial dimensions. We use the same padding - // as Conv2DBackpropInput. - xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - auto* padding = padding_config.mutable_dimensions(dim); - padding->set_edge_padding_low(dims.spatial_dims[i].pad_before); - padding->set_edge_padding_high(dims.spatial_dims[i].pad_after); - padding->set_interior_padding(dims.spatial_dims[i].stride - 1); - } - - auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); - - // in_backprop = padded_gradients <conv> ones - std::vector<int64> ones(num_dims(), 1LL); - auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = xla::ReduceWindow( - XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), ksize_, - /* window_strides=*/ones, xla::Padding::kValid); - ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); + xla::PrimitiveType xla_reduction_type; + auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type)); + auto converted_out_backprop = + xla::ConvertElementType(out_backprop, xla_reduction_type); + auto xla_data_format = + XlaTensorFormat(data_format_, gradients_shape.dims() - 2); + auto padding_values = + MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s, + xla_padding, xla_data_format); + auto in_backprop = + xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(), + ksize_, stride_int64s, padding_values, xla_data_format, + /*counts_include_padding=*/padding_ == VALID); + // Convert the pooling result back to the input type before returning it. + xla::PrimitiveType xla_out_backprop_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), + &xla_out_backprop_type)); + ctx->SetOutput(0, + xla::ConvertElementType(in_backprop, xla_out_backprop_type)); } protected: |