aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc509
1 files changed, 509 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
new file mode 100644
index 0000000000..c9a1be4940
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -0,0 +1,509 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for 2D convolution.
+
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+// Returns the expanded size of a filter used for depthwise convolution.
+// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
+xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
+ int num_dims = shape.dimensions_size();
+ CHECK_GE(num_dims, 2); // Crash OK
+ xla::Shape expanded_shape = shape;
+ expanded_shape.set_dimensions(
+ num_dims - 1,
+ shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
+ return expanded_shape;
+}
+
+// Create a mask for depthwise convolution that will make a normal convolution
+// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
+// depthwise filter this returns a [2, 2, 3, 6] tensor
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// The first step is to create a one tensor, A, that is [3]
+// 0 1 2
+//
+// and another tensor, B, that is [3 * 2]
+// 0 1 2 3 4 5
+//
+// and divide B it by 2 to get
+// 0 0 1 1 2 2
+//
+// then we broadcast the B to [2, 2, 3, 3 * 2]
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// Finally compare A and broadcasted B in dimension 2 amd return the result at
+// the beginning of the comment.
+xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
+ xla::XlaBuilder* builder) {
+ xla::Shape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ int64 depthwise_multiplier =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 1);
+ int64 input_feature =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 2);
+
+ // Create a M sized linspace and an M*N sized linspace that will be
+ // broadcasted into perpendicular dimensions and compared.
+ xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+ xla::XlaOp expanded_feature_iota =
+ xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
+
+ // Divide the M*N sized linspace by the depthwise_multiplier to create
+ // [0 0 1 1 2 2] in the example in the function comment.
+ expanded_feature_iota =
+ xla::Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
+
+ // Broadcast the N*M linspace to [H, W, ..., M, M*N].
+ std::vector<int64> expanded_feature_broadcast_dims(
+ expanded_filter_shape.dimensions().begin(),
+ expanded_filter_shape.dimensions().end());
+ expanded_feature_broadcast_dims.pop_back();
+ auto broadcasted_expanded_feature_iota =
+ xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
+
+ // Compare the broadcasted linspace to the input feature linspace in the
+ // input feature dimension to create a diagonal predicate.
+ return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dimensions_size() - 2});
+}
+
+// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
+// build a depthwise convolution.
+xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter) {
+ int64 input_feature_dim = filter_shape.dimensions_size() - 2;
+ int64 output_feature_dim = filter_shape.dimensions_size() - 1;
+ int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
+ int64 input_feature = filter_shape.dimensions(input_feature_dim);
+
+ // Create a [H, W, ..., 1, N*M] reshape of the filter.
+ xla::Shape implicit_broadcast_filter_shape = filter_shape;
+ implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
+ implicit_broadcast_filter_shape.set_dimensions(
+ output_feature_dim, depthwise_multiplier * input_feature);
+ return xla::Reshape(
+ filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
+}
+
+// Reduces the results of the convolution with an expanded filter to the
+// non-expanded filter.
+xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter_backprop,
+ xla::XlaBuilder* builder) {
+ auto masked_expanded_filter =
+ xla::Select(CreateExpandedFilterMask(filter_shape, builder),
+ filter_backprop, xla::ZerosLike(filter_backprop));
+
+ auto elem_type = filter_shape.element_type();
+ return xla::Reshape(
+ // This reduce does not need inputs to be converted with
+ // XlaHelpers::SumAccumulationType() since the select above guarantees
+ // that only one element is non zero, so there cannot be accumulated
+ // precision error.
+ xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
+ CreateScalarAddComputation(elem_type, builder),
+ {filter_shape.dimensions_size() - 2}),
+ xla::AsInt64Slice(filter_shape.dimensions()));
+}
+
+// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
+// convolutions (as currently implemented).
+Status CheckConvAttrs(const ConvOpAttrs& attrs) {
+ const int num_dims = attrs.num_spatial_dims + 2;
+ if (attrs.strides.size() != num_dims) {
+ return errors::InvalidArgument("Sliding window strides field must specify ",
+ num_dims, " dimensions");
+ }
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+ if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
+ return errors::Unimplemented(
+ "Current implementation does not yet support strides in the batch and "
+ "depth dimensions.");
+ }
+ if (attrs.dilations.size() != num_dims) {
+ return errors::InvalidArgument("Dilations field must specify ", num_dims,
+ " dimensions");
+ }
+ if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
+ return errors::Unimplemented(
+ "Current implementation does not support dilations in the batch and "
+ "depth dimensions.");
+ }
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ if (attrs.dilations[input_dim] < 1) {
+ return errors::Unimplemented("Dilation values must be positive; ", i,
+ "th spatial dimension had dilation ",
+ attrs.dilations[input_dim]);
+ }
+ }
+ return Status::OK();
+}
+
+// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
+// to TensorShapes.
+Status ConvBackpropComputeDimensionsV2XlaShapes(
+ StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
+ const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
+ absl::Span<const int32> dilations, const std::vector<int32>& strides,
+ Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
+ TensorShape input_tensor_shape, filter_tensor_shape,
+ out_backprop_tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
+ TF_RETURN_IF_ERROR(
+ XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
+ return ConvBackpropComputeDimensionsV2(
+ label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
+ out_backprop_tensor_shape, dilations, strides, padding, data_format,
+ dims);
+}
+
+} // anonymous namespace
+
+xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
+ bool depthwise,
+ OpKernelConstruction* ctx) {
+ ConvOpAttrs attrs;
+ attrs.num_spatial_dims = num_spatial_dims;
+ attrs.depthwise = depthwise;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
+
+ string data_format;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
+ if (!FormatFromString(data_format, &attrs.data_format)) {
+ return errors::InvalidArgument("Invalid data format: ", data_format);
+ }
+
+ return attrs;
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
+ xla::XlaOp conv_input,
+ xla::XlaOp filter,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = conv_input.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
+ // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+
+ // For 2D convolution, there should be 4 dimensions.
+ int num_dims = attrs.num_spatial_dims + 2;
+ if (input_shape.dimensions_size() != num_dims) {
+ return errors::InvalidArgument("input must be ", num_dims, "-dimensional",
+ input_shape.DebugString());
+ }
+ if (filter_shape.dimensions_size() != num_dims) {
+ return errors::InvalidArgument(
+ "filter must be ", num_dims,
+ "-dimensional: ", filter_shape.DebugString());
+ }
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
+ // The 'C' dimension for input is in_depth. It must be the same as
+ // the filter's in_depth.
+ if (in_depth != input_shape.dimensions(feature_dim)) {
+ return errors::InvalidArgument(
+ "input and filter must have the same depth: ", in_depth, " vs ",
+ input_shape.dimensions(feature_dim));
+ }
+
+ if (attrs.depthwise) {
+ filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
+ }
+
+ xla::ConvolutionDimensionNumbers dims;
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+
+ dims.set_input_batch_dimension(batch_dim);
+ dims.set_output_batch_dimension(batch_dim);
+ dims.set_input_feature_dimension(feature_dim);
+ dims.set_output_feature_dimension(feature_dim);
+ dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
+ dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dims.add_input_spatial_dimensions(dim);
+ dims.add_kernel_spatial_dimensions(i);
+ dims.add_output_spatial_dimensions(dim);
+ window_strides[i] = attrs.strides.at(dim);
+ rhs_dilation[i] = attrs.dilations.at(dim);
+
+ int64 unused_output_size;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
+ input_shape.dimensions(dim), filter_shape.dimensions(i),
+ rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
+ &padding[i].first, &padding[i].second));
+ }
+
+ return xla::ConvGeneralDilated(
+ conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
+ dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ int num_dims = attrs.num_spatial_dims + 2;
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ auto* builder = filter.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(out_backprop));
+
+ xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
+ out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
+ attrs.data_format, &dims));
+
+ // The input gradients are computed by a convolution of the output
+ // gradients and the filter, with some appropriate padding. See the
+ // comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(batch_dim);
+ dnums.set_output_batch_dimension(batch_dim);
+ dnums.set_input_feature_dimension(feature_dim);
+ dnums.set_output_feature_dimension(feature_dim);
+
+ // TF filter shape is [ H, W, ..., inC, outC ]
+ // Transpose the input and output features for computing the gradient.
+ dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
+ dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);
+
+ std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(i);
+ dnums.add_output_spatial_dimensions(dim);
+
+ kernel_spatial_dims[i] = i;
+ padding[i] = {dims.spatial_dims[i].pad_before,
+ dims.spatial_dims[i].pad_after};
+ lhs_dilation[i] = dims.spatial_dims[i].stride;
+ rhs_dilation[i] = attrs.dilations[dim];
+ }
+
+ // Mirror the filter in the spatial dimensions.
+ xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
+
+ // activation gradients
+ // = gradients (with padding and dilation) <conv> mirrored_weights
+ return xla::ConvGeneralDilated(
+ out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
+ lhs_dilation, rhs_dilation, dnums,
+ /*feature_group_count=*/
+ attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
+ filter_shape.dimensions(attrs.num_spatial_dims + 1)
+ : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = activations.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
+ builder->GetShape(activations));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(gradients));
+ const xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, activations_shape,
+ expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
+ attrs.padding, attrs.data_format, &dims));
+
+ // The filter gradients are computed by a convolution of the input
+ // activations and the output gradients, with some appropriate padding.
+ // See the comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+
+ // The activations (inputs) form the LHS of the convolution.
+ // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
+ // For the gradient computation, we flip the roles of the batch and
+ // feature dimensions.
+ // Each spatial entry has size in_depth * batch
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int num_dims = attrs.num_spatial_dims + 2;
+ int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ // Swap n_dim and c_dim in the activations.
+ dnums.set_input_batch_dimension(c_dim);
+ dnums.set_input_feature_dimension(n_dim);
+
+ // The gradients become the RHS of the convolution.
+ // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
+ // where the batch becomes the input feature for the convolution.
+ dnums.set_kernel_input_feature_dimension(n_dim);
+ dnums.set_kernel_output_feature_dimension(c_dim);
+
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+
+ // Tensorflow filter shape is [ H, W, ..., inC, outC ].
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ dnums.add_output_spatial_dimensions(i);
+ }
+ dnums.set_output_batch_dimension(attrs.num_spatial_dims);
+ dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(dim);
+
+ // We will also need to pad the input with zeros such that after the
+ // convolution, we get the right size for the filter.
+ // The padded_in_rows should be such that when we convolve this with the
+ // expanded_out_rows as a filter, we should get filter_rows back.
+ //
+ const int64 padded_in_size =
+ dims.spatial_dims[i].expanded_output_size +
+ (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];
+
+ // However it can be smaller than input_rows: in this
+ // case it means some of the inputs are not used.
+ //
+ // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
+ //
+ // INPUT = [ A B C ]
+ //
+ // FILTER = [ x y ]
+ //
+ // and the output will only have one column: a = A * x + B * y
+ //
+ // and input "C" is not used at all.
+ //
+ // We apply negative padding in this case.
+ const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
+
+ // + For the VALID padding, we don't pad anything on the top/left side
+ // and pad the bottom/right side with the remaining space.
+ // + For the SAME padding, we pad top/left side the same as bottom/right
+ // side.
+ //
+ // In addition, if the padded input size is smaller than the input size,
+ // we need to ignore some training elements of the input. We do this by
+ // applying negative padding on the right/bottom.
+ const int64 pad_before =
+ attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
+
+ padding[i] = {pad_before, pad_total - pad_before};
+ rhs_dilation[i] = dims.spatial_dims[i].stride;
+ window_strides[i] = attrs.dilations[dim];
+ }
+
+ // Besides padding the input, we will also expand output_rows to
+ // expanded_out_rows = (output_rows - 1) * stride + 1
+ // with zeros in between:
+ //
+ // a . . . b . . . c . . . d . . . e
+ //
+ // This is done by specifying the window dilation factors in the
+ // convolution HLO below.
+ auto filter_backprop =
+ xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
+ /*lhs_dilation=*/ones, rhs_dilation, dnums);
+
+ if (attrs.depthwise) {
+ filter_backprop = ContractFilterForDepthwiseBackprop(
+ filter_shape, filter_backprop, activations.builder());
+ }
+
+ return filter_backprop;
+}
+
+} // namespace tensorflow