aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-09-17 23:09:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 23:13:50 -0700
commit9cc7bbe5b476bec556d7dce235996a03775d7492 (patch)
tree7943f0d1eb95737bd2b3792facf1a39ec3e7d370 /tensorflow/compiler/tf2xla
parent7c826588b058c14fd8c152bedb4e256c57ae1248 (diff)
[XLA] Refactor conv_ops emitters to make them reusable.
PiperOrigin-RevId: 213398930
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD22
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc509
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h69
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc551
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.cc14
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.h5
6 files changed, 661 insertions, 509 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 46794f7b50..3e823254d3 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -113,6 +113,7 @@ tf_kernel_library(
"shape_util.h",
],
deps = [
+ ":conv_op_helpers",
":if_op",
":while_op",
"//tensorflow/compiler/tf2xla:common",
@@ -172,6 +173,27 @@ tf_kernel_library(
],
)
+# Helpers for lowering TF convolution ops to XLA; used by conv_ops.cc.
+cc_library(
+ name = "conv_op_helpers",
+ srcs = ["conv_op_helpers.cc"],
+ hdrs = ["conv_op_helpers.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:numeric",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/kernels:bounds_check",
+ "//tensorflow/core/kernels:conv_ops",
+ "//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
tf_kernel_library(
name = "while_op",
srcs = ["while_op.cc"],
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
new file mode 100644
index 0000000000..c9a1be4940
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -0,0 +1,509 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helpers for translating TensorFlow convolution ops into XLA ops.
+
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+// Computes the shape of the expanded filter used for depthwise convolution:
+// given a filter shape [H, W, ..., M, N], the result is [H, W, ..., M, M*N].
+// CHECK-fails if `shape` has fewer than two dimensions.
+xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
+  const int num_dims = shape.dimensions_size();
+  CHECK_GE(num_dims, 2);  // Crash OK
+  const int in_feature_dim = num_dims - 2;
+  const int out_feature_dim = num_dims - 1;
+  // Only the last dimension changes: it becomes M * N.
+  xla::Shape result = shape;
+  result.set_dimensions(out_feature_dim, shape.dimensions(in_feature_dim) *
+                                             shape.dimensions(out_feature_dim));
+  return result;
+}
+
+// Create a mask for depthwise convolution that will make a normal convolution
+// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
+// depthwise filter this returns a [2, 2, 3, 6] tensor
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// The first step is to create a 1-D tensor, A, that is [3]
+// 0 1 2
+//
+// and another tensor, B, that is [3 * 2]
+// 0 1 2 3 4 5
+//
+// and divide B by 2 to get
+// 0 0 1 1 2 2
+//
+// then we broadcast B to [2, 2, 3, 3 * 2]
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// Finally compare A and broadcasted B in dimension 2 and return the result at
+// the beginning of the comment.
+//
+// `filter_shape` is the non-expanded [H, W, ..., M, N] filter shape; the ops
+// are created in `builder`.
+xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
+ xla::XlaBuilder* builder) {
+ xla::Shape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ int64 depthwise_multiplier =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 1);
+ int64 input_feature =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 2);
+
+ // Create a M sized linspace and an M*N sized linspace that will be
+ // broadcasted into perpendicular dimensions and compared.
+ xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+ xla::XlaOp expanded_feature_iota =
+ xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
+
+ // Divide the M*N sized linspace by the depthwise_multiplier to create
+ // [0 0 1 1 2 2] in the example in the function comment.
+ expanded_feature_iota =
+ xla::Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
+
+ // Broadcast the N*M linspace to [H, W, ..., M, M*N].
+ // The trailing M*N dimension is dropped from the broadcast sizes because the
+ // linspace being broadcast already provides it.
+ std::vector<int64> expanded_feature_broadcast_dims(
+ expanded_filter_shape.dimensions().begin(),
+ expanded_filter_shape.dimensions().end());
+ expanded_feature_broadcast_dims.pop_back();
+ auto broadcasted_expanded_feature_iota =
+ xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
+
+ // Compare the broadcasted linspace to the input feature linspace in the
+ // input feature dimension to create a diagonal predicate.
+ return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dimensions_size() - 2});
+}
+
+// Turns a depthwise filter of shape [H, W, ..., M, N] into the shape
+// [H, W, ..., 1, M*N] that is used to build a depthwise convolution.
+xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
+                                                const xla::XlaOp& filter) {
+  const int64 in_dim = filter_shape.dimensions_size() - 2;
+  const int64 out_dim = filter_shape.dimensions_size() - 1;
+
+  // Fold the input-feature extent into the output-feature dimension and
+  // leave a size-1 input-feature dimension behind.
+  const int64 folded_features =
+      filter_shape.dimensions(out_dim) * filter_shape.dimensions(in_dim);
+  xla::Shape reshaped = filter_shape;
+  reshaped.set_dimensions(in_dim, 1);
+  reshaped.set_dimensions(out_dim, folded_features);
+  return xla::Reshape(filter, xla::AsInt64Slice(reshaped.dimensions()));
+}
+
+// Reduces the results of the convolution with an expanded filter to the
+// non-expanded filter.
+//
+// `filter_shape` is the non-expanded filter shape and `filter_backprop` the
+// gradient computed for the expanded filter. Entries outside the depthwise
+// mask are zeroed, the result is summed over the input feature dimension, and
+// the sum is reshaped to `filter_shape`.
+xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter_backprop,
+ xla::XlaBuilder* builder) {
+ // Keep only the elements selected by the expanded filter mask.
+ auto masked_expanded_filter =
+ xla::Select(CreateExpandedFilterMask(filter_shape, builder),
+ filter_backprop, xla::ZerosLike(filter_backprop));
+
+ auto elem_type = filter_shape.element_type();
+ return xla::Reshape(
+ // This reduce does not need inputs to be converted with
+ // XlaHelpers::SumAccumulationType() since the select above guarantees
+ // that only one element is non zero, so there cannot be accumulated
+ // precision error.
+ xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
+ CreateScalarAddComputation(elem_type, builder),
+ {filter_shape.dimensions_size() - 2}),
+ xla::AsInt64Slice(filter_shape.dimensions()));
+}
+
+// Validates the properties of ConvOpAttrs that hold for all kinds of XLA
+// convolutions (as currently implemented): strides and dilations must have
+// one entry per tensor dimension, must be 1 in the batch and feature
+// dimensions, and every spatial dilation must be at least 1.
+Status CheckConvAttrs(const ConvOpAttrs& attrs) {
+  const int num_dims = attrs.num_spatial_dims + 2;
+
+  // Strides: correct length, then unit stride in batch/feature dimensions.
+  if (attrs.strides.size() != num_dims) {
+    return errors::InvalidArgument("Sliding window strides field must specify ",
+                                   num_dims, " dimensions");
+  }
+  const int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+  const int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+  const bool unit_outer_strides =
+      attrs.strides[batch_dim] == 1 && attrs.strides[feature_dim] == 1;
+  if (!unit_outer_strides) {
+    return errors::Unimplemented(
+        "Current implementation does not yet support strides in the batch and "
+        "depth dimensions.");
+  }
+
+  // Dilations: correct length, then unit dilation in batch/feature dimensions.
+  if (attrs.dilations.size() != num_dims) {
+    return errors::InvalidArgument("Dilations field must specify ", num_dims,
+                                   " dimensions");
+  }
+  const bool unit_outer_dilations =
+      attrs.dilations[batch_dim] == 1 && attrs.dilations[feature_dim] == 1;
+  if (!unit_outer_dilations) {
+    return errors::Unimplemented(
+        "Current implementation does not support dilations in the batch and "
+        "depth dimensions.");
+  }
+
+  // Spatial dilations must be positive.
+  for (int spatial = 0; spatial < attrs.num_spatial_dims; ++spatial) {
+    const int tensor_dim =
+        GetTensorSpatialDimIndex(num_dims, attrs.data_format, spatial);
+    const auto dilation = attrs.dilations[tensor_dim];
+    if (dilation < 1) {
+      return errors::Unimplemented("Dilation values must be positive; ",
+                                   spatial,
+                                   "th spatial dimension had dilation ",
+                                   dilation);
+    }
+  }
+  return Status::OK();
+}
+
+// Adapter for ConvBackpropComputeDimensionsV2 that accepts xla::Shape
+// arguments. Each shape is converted to a TensorShape (a conversion error is
+// returned immediately) and the remaining arguments are forwarded unchanged.
+Status ConvBackpropComputeDimensionsV2XlaShapes(
+    StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
+    const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
+    absl::Span<const int32> dilations, const std::vector<int32>& strides,
+    Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
+  TensorShape input_ts;
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_ts));
+
+  TensorShape filter_ts;
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_ts));
+
+  TensorShape out_backprop_ts;
+  TF_RETURN_IF_ERROR(
+      XLAShapeToTensorShape(out_backprop_shape, &out_backprop_ts));
+
+  return ConvBackpropComputeDimensionsV2(label, num_spatial_dims, input_ts,
+                                         filter_ts, out_backprop_ts, dilations,
+                                         strides, padding, data_format, dims);
+}
+
+} // anonymous namespace
+
+// Builds a ConvOpAttrs from `num_spatial_dims` and `depthwise` plus the
+// "dilations", "strides", "padding" and "data_format" attributes read from
+// `ctx`. Returns the first attribute-read error encountered, or
+// InvalidArgument if FormatFromString rejects the data_format string.
+xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
+ bool depthwise,
+ OpKernelConstruction* ctx) {
+ ConvOpAttrs attrs;
+ attrs.num_spatial_dims = num_spatial_dims;
+ attrs.depthwise = depthwise;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
+
+ // data_format is stored on the op as a string and parsed into the enum here.
+ string data_format;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
+ if (!FormatFromString(data_format, &attrs.data_format)) {
+ return errors::InvalidArgument("Invalid data format: ", data_format);
+ }
+
+ return attrs;
+}
+
+// Builds the XLA convolution that implements a TF forward convolution.
+//
+// `conv_input` is laid out according to attrs.data_format and `filter` has the
+// form [filter_rows, filter_cols, ..., in_depth, out_depth]. Returns
+// InvalidArgument if either operand does not have attrs.num_spatial_dims + 2
+// dimensions or if the input and filter depths disagree; errors from attribute
+// validation, shape lookup and padding computation are propagated. The
+// type_string parameter is currently unused.
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
+                                               xla::XlaOp conv_input,
+                                               xla::XlaOp filter,
+                                               const ConvOpAttrs& attrs) {
+  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+  auto* builder = conv_input.builder();
+  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
+  // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
+  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+
+  // For 2D convolution, there should be 4 dimensions.
+  int num_dims = attrs.num_spatial_dims + 2;
+  if (input_shape.dimensions_size() != num_dims) {
+    // The ": " separator keeps the shape readable and matches the filter
+    // message below.
+    return errors::InvalidArgument("input must be ", num_dims,
+                                   "-dimensional: ", input_shape.DebugString());
+  }
+  if (filter_shape.dimensions_size() != num_dims) {
+    return errors::InvalidArgument(
+        "filter must be ", num_dims,
+        "-dimensional: ", filter_shape.DebugString());
+  }
+
+  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+  // The last two dimensions of the filter are the input and output depths, so
+  // in_depth lives at index num_spatial_dims.
+  int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
+  // The 'C' dimension for input is in_depth. It must be the same as
+  // the filter's in_depth.
+  if (in_depth != input_shape.dimensions(feature_dim)) {
+    return errors::InvalidArgument(
+        "input and filter must have the same depth: ", in_depth, " vs ",
+        input_shape.dimensions(feature_dim));
+  }
+
+  if (attrs.depthwise) {
+    filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
+  }
+
+  xla::ConvolutionDimensionNumbers dims;
+  std::vector<int64> window_strides(attrs.num_spatial_dims);
+  std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
+  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+
+  dims.set_input_batch_dimension(batch_dim);
+  dims.set_output_batch_dimension(batch_dim);
+  dims.set_input_feature_dimension(feature_dim);
+  dims.set_output_feature_dimension(feature_dim);
+  dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
+  dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+    const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+    dims.add_input_spatial_dimensions(dim);
+    dims.add_kernel_spatial_dimensions(i);
+    dims.add_output_spatial_dimensions(dim);
+    window_strides[i] = attrs.strides.at(dim);
+    rhs_dilation[i] = attrs.dilations.at(dim);
+
+    // Only the padding amounts are needed here; the output size is discarded.
+    int64 unused_output_size;
+    TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
+        input_shape.dimensions(dim), filter_shape.dimensions(i),
+        rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
+        &padding[i].first, &padding[i].second));
+  }
+
+  return xla::ConvGeneralDilated(
+      conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
+      dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
+}
+
+// Builds the XLA convolution that computes the gradient of a TF convolution
+// with respect to its input.
+//
+// `input_shape` is the shape of the forward input, `filter` is the forward
+// filter ([ H, W, ..., inC, outC ]) and `out_backprop` holds the output
+// gradients. `type_string` is passed as the label to the dimension
+// computation. Errors from attribute validation, shape lookup and the
+// dimension computation are propagated.
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ int num_dims = attrs.num_spatial_dims + 2;
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ auto* builder = filter.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(out_backprop));
+
+ // The dimension computation below is given the expanded filter shape when
+ // the convolution is depthwise.
+ xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
+ out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
+ attrs.data_format, &dims));
+
+ // The input gradients are computed by a convolution of the output
+ // gradients and the filter, with some appropriate padding. See the
+ // comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(batch_dim);
+ dnums.set_output_batch_dimension(batch_dim);
+ dnums.set_input_feature_dimension(feature_dim);
+ dnums.set_output_feature_dimension(feature_dim);
+
+ // TF filter shape is [ H, W, ..., inC, outC ]
+ // Transpose the input and output features for computing the gradient.
+ dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
+ dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);
+
+ std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(i);
+ dnums.add_output_spatial_dimensions(dim);
+
+ kernel_spatial_dims[i] = i;
+ padding[i] = {dims.spatial_dims[i].pad_before,
+ dims.spatial_dims[i].pad_after};
+ // The forward stride becomes a dilation of the gradients (the LHS).
+ lhs_dilation[i] = dims.spatial_dims[i].stride;
+ rhs_dilation[i] = attrs.dilations[dim];
+ }
+
+ // Mirror the filter in the spatial dimensions.
+ xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
+
+ // activation gradients
+ // = gradients (with padding and dilation) <conv> mirrored_weights
+ return xla::ConvGeneralDilated(
+ out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
+ lhs_dilation, rhs_dilation, dnums,
+ /*feature_group_count=*/
+ attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
+ filter_shape.dimensions(attrs.num_spatial_dims + 1)
+ : 1);
+}
+
+// Builds the XLA convolution that computes the gradient of a TF convolution
+// with respect to its filter.
+//
+// `activations` is the forward input, `filter_shape` the shape of the forward
+// filter and `gradients` the output gradients. `type_string` is passed as the
+// label to the dimension computation. For depthwise convolutions the result
+// is contracted back to `filter_shape`. Errors from attribute validation,
+// shape lookup and the dimension computation are propagated.
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = activations.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
+ builder->GetShape(activations));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(gradients));
+ const xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, activations_shape,
+ expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
+ attrs.padding, attrs.data_format, &dims));
+
+ // The filter gradients are computed by a convolution of the input
+ // activations and the output gradients, with some appropriate padding.
+ // See the comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+
+ // The activations (inputs) form the LHS of the convolution.
+ // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
+ // For the gradient computation, we flip the roles of the batch and
+ // feature dimensions.
+ // Each spatial entry has size in_depth * batch
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int num_dims = attrs.num_spatial_dims + 2;
+ int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ // Swap n_dim and c_dim in the activations.
+ dnums.set_input_batch_dimension(c_dim);
+ dnums.set_input_feature_dimension(n_dim);
+
+ // The gradients become the RHS of the convolution.
+ // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
+ // where the batch becomes the input feature for the convolution.
+ dnums.set_kernel_input_feature_dimension(n_dim);
+ dnums.set_kernel_output_feature_dimension(c_dim);
+
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+
+ // Tensorflow filter shape is [ H, W, ..., inC, outC ].
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ dnums.add_output_spatial_dimensions(i);
+ }
+ dnums.set_output_batch_dimension(attrs.num_spatial_dims);
+ dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(dim);
+
+ // We will also need to pad the input with zeros such that after the
+ // convolution, we get the right size for the filter.
+ // The padded_in_rows should be such that when we convolve this with the
+ // expanded_out_rows as a filter, we should get filter_rows back.
+ //
+ const int64 padded_in_size =
+ dims.spatial_dims[i].expanded_output_size +
+ (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];
+
+ // However it can be smaller than input_rows: in this
+ // case it means some of the inputs are not used.
+ //
+ // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
+ //
+ // INPUT = [ A B C ]
+ //
+ // FILTER = [ x y ]
+ //
+ // and the output will only have one column: a = A * x + B * y
+ //
+ // and input "C" is not used at all.
+ //
+ // We apply negative padding in this case.
+ const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
+
+ // + For the VALID padding, we don't pad anything on the top/left side
+ // and pad the bottom/right side with the remaining space.
+ // + For the SAME padding, we pad top/left side the same as bottom/right
+ // side.
+ //
+ // In addition, if the padded input size is smaller than the input size,
+ // we need to ignore some trailing elements of the input. We do this by
+ // applying negative padding on the right/bottom.
+ const int64 pad_before =
+ attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
+
+ padding[i] = {pad_before, pad_total - pad_before};
+ // Note the role swap relative to the forward op: the forward stride is
+ // used as the dilation of the gradients (the RHS), and the forward dilation
+ // is used as the window stride.
+ rhs_dilation[i] = dims.spatial_dims[i].stride;
+ window_strides[i] = attrs.dilations[dim];
+ }
+
+ // Besides padding the input, we will also expand output_rows to
+ // expanded_out_rows = (output_rows - 1) * stride + 1
+ // with zeros in between:
+ //
+ // a . . . b . . . c . . . d . . . e
+ //
+ // This is done by specifying the window dilation factors in the
+ // convolution HLO below.
+ auto filter_backprop =
+ xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
+ /*lhs_dilation=*/ones, rhs_dilation, dnums);
+
+ if (attrs.depthwise) {
+ filter_backprop = ContractFilterForDepthwiseBackprop(
+ filter_shape, filter_backprop, activations.builder());
+ }
+
+ return filter_backprop;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
new file mode 100644
index 0000000000..6e1b70a478
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+// This header exposes utilities for translating TensorFlow convolution ops into
+// XLA ops.
+//
+// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g.
+// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in
+// this header to implement a new and exciting convolution op, for example a
+// fused TensorFlow op that contains a convolution and other things.
+
+namespace tensorflow {
+
+// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
+// convolution.
+struct ConvOpAttrs {
+ // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`.
+ // `num_spatial_dims` and `depthwise` are taken from the arguments; the
+ // remaining fields come from the op's attributes. Returns an error status if
+ // an attribute cannot be read or the data format is invalid.
+ static xla::StatusOr<ConvOpAttrs> Create(int num_spatial_dims, bool depthwise,
+ OpKernelConstruction* ctx);
+
+ // Whether the convolution is depthwise.
+ bool depthwise;
+ // Number of spatial dimensions; operands have num_spatial_dims + 2
+ // dimensions.
+ int num_spatial_dims;
+ // Per-dimension dilations, indexed by tensor dimension under `data_format`.
+ std::vector<int32> dilations;
+ // Per-dimension strides, indexed by tensor dimension under `data_format`.
+ std::vector<int32> strides;
+ // Padding scheme of the convolution.
+ Padding padding;
+ // Layout used to locate the batch, feature and spatial dimensions.
+ TensorFormat data_format;
+};
+
+// Creates a new XLA forward or backward convolution with the given inputs and
+// attributes.
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
+ xla::XlaOp conv_input,
+ xla::XlaOp filter,
+ const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 674720e22f..cd7c820be0 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -15,12 +15,17 @@ limitations under the License.
// XLA-specific Ops for 2D convolution.
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -33,250 +38,28 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
-
namespace {
-// Returns the expanded size of a filter used for depthwise convolution.
-// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
-TensorShape ExpandedFilterShapeForDepthwiseConvolution(
- const TensorShape& shape) {
- int num_dims = shape.dims();
- CHECK_GE(num_dims, 2);
- TensorShape expanded_shape = shape;
- expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) *
- shape.dim_size(num_dims - 1));
- return expanded_shape;
-}
-
-// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
-xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
- xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- return xla::Broadcast(XlaHelpers::Zero(builder, dtype),
- expanded_filter_shape.dim_sizes());
-}
-
-// Create a mask for depthwise convolution that will make a normal convolution
-// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
-// depthwise filter this returns a [2, 2, 3, 6] tensor
-// 1 1 0 0 0 0 1 1 0 0 0 0
-// 0 0 1 1 0 0 0 0 1 1 0 0
-// 0 0 0 0 1 1 0 0 0 0 1 1
-//
-// 1 1 0 0 0 0 1 1 0 0 0 0
-// 0 0 1 1 0 0 0 0 1 1 0 0
-// 0 0 0 0 1 1 0 0 0 0 1 1
-//
-// The first step is to create a one tensor, A, that is [3]
-// 0 1 2
-//
-// and another tensor, B, that is [3 * 2]
-// 0 1 2 3 4 5
-//
-// and divide B it by 2 to get
-// 0 0 1 1 2 2
-//
-// then we broadcast the B to [2, 2, 3, 3 * 2]
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-//
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-//
-// Finally compare A and broadcasted B in dimension 2 amd return the result at
-// the beginning of the comment.
-xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
- xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
- int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
-
- // Create a M sized linspace and an M*N sized linspace that will be
- // broadcasted into perpendicular dimensions and compared.
- xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
- xla::XlaOp expanded_feature_iota =
- xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
-
- // Divide the M*N sized linspace by the depthwise_multiplier to create
- // [0 0 1 1 2 2] in the example in the function comment.
- expanded_feature_iota =
- xla::Div(expanded_feature_iota,
- XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
- depthwise_multiplier));
-
- // Broadcast the N*M linspace to [H, W, ..., M, M*N].
- auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
- expanded_feature_broadcast_dims.pop_back();
- auto broadcasted_expanded_feature_iota =
- xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
-
- // Compare the broadcasted linspace to the input feature linspace in the
- // input feature dimension to create a diagonal predicate.
- return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
- {expanded_filter_shape.dims() - 2});
-}
-
-// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
-// build a depthwise convolution.
-xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape,
- const xla::XlaOp& filter) {
- int64 input_feature_dim = filter_shape.dims() - 2;
- int64 output_feature_dim = filter_shape.dims() - 1;
- int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim);
- int64 input_feature = filter_shape.dim_size(input_feature_dim);
-
- // Create a [H, W, ..., 1, N*M] reshape of the filter.
- TensorShape implicit_broadcast_filter_shape = filter_shape;
- implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1);
- implicit_broadcast_filter_shape.set_dim(output_feature_dim,
- depthwise_multiplier * input_feature);
- return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
-}
-
-// Reduces the results of the convolution with an expanded filter to the
-// non-expanded filter.
-xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
- const TensorShape& filter_shape,
- DataType dtype,
- const xla::XlaOp& filter_backprop,
- xla::XlaBuilder* builder) {
- auto masked_expanded_filter = xla::Select(
- CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
- CreateExpandedZero(filter_shape, dtype, builder));
- return xla::Reshape(
- // This reduce does not need inputs to be converted with
- // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with
- // ExpandedZero guarantees that only one element is non zero, so there
- // cannot be accumulated precision error.
- xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
- *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}),
- filter_shape.dim_sizes());
-}
-
class ConvOp : public XlaOpKernel {
public:
explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
-
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, strides_.size() == num_dims(),
- errors::InvalidArgument("Sliding window strides field must "
- "specify ",
- num_dims(), " dimensions"));
- int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- OP_REQUIRES(
- ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- const TensorShape input_shape = ctx->InputShape(0);
- // Input filter is of the following dimensions:
- // [ filter_rows, filter_cols, ..., in_depth, out_depth]
- const TensorShape filter_shape = ctx->InputShape(1);
-
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(
- ctx, input_shape.dims() == num_dims(),
- errors::InvalidArgument("input must be ", num_dims(), "-dimensional",
- input_shape.DebugString()));
- OP_REQUIRES(
- ctx, filter_shape.dims() == num_dims(),
- errors::InvalidArgument("filter must be ", num_dims(),
- "-dimensional: ", filter_shape.DebugString()));
-
- // The last two dimension of the filter are the input and output shapes.
- const int64 in_depth = filter_shape.dim_size(num_spatial_dims_);
-
- // The 'C' dimension for input is in_depth. It must be the same as
- // the filter's in_depth.
- OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", in_depth,
- " vs ", input_shape.dim_size(feature_dim)));
-
- xla::XlaOp filter = ctx->Input(1);
- if (depthwise_) {
- filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
- }
-
- xla::ConvolutionDimensionNumbers dims;
- std::vector<int64> window_strides(num_spatial_dims_);
- std::vector<int64> lhs_dilation(num_spatial_dims_, 1);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
-
- dims.set_input_batch_dimension(batch_dim);
- dims.set_output_batch_dimension(batch_dim);
- dims.set_input_feature_dimension(feature_dim);
- dims.set_output_feature_dimension(feature_dim);
- dims.set_kernel_input_feature_dimension(num_spatial_dims_);
- dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1);
-
- for (int i = 0; i < num_spatial_dims_; ++i) {
- const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dims.add_input_spatial_dimensions(dim);
- dims.add_kernel_spatial_dimensions(i);
- dims.add_output_spatial_dimensions(dim);
- window_strides[i] = strides_.at(dim);
- rhs_dilation[i] = dilations_.at(dim);
-
- int64 unused_output_size;
- OP_REQUIRES_OK(
- ctx, GetWindowedOutputSizeVerboseV2(
- input_shape.dim_size(dim), filter_shape.dim_size(i),
- rhs_dilation[i], window_strides[i], padding_,
- &unused_output_size, &padding[i].first, &padding[i].second));
- }
-
- xla::XlaOp conv = xla::ConvGeneralDilated(
- ctx->Input(0), filter, window_strides, padding, lhs_dilation,
- rhs_dilation, dims,
- /*feature_group_count=*/depthwise_ ? in_depth : 1);
- ctx->SetOutput(0, conv);
+ xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
+ ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
+ OP_REQUIRES_OK(ctx, conv.status());
+ ctx->SetOutput(0, conv.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvOp);
@@ -308,124 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel {
public:
explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, strides_.size() == num_dims(),
- errors::InvalidArgument("Sliding window strides field must "
- "specify ",
- num_dims(), " dimensions"));
- int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- OP_REQUIRES(
- ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- TensorShape input_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
-
- const TensorShape filter_shape = ctx->InputShape(1);
- const TensorShape out_backprop_shape = ctx->InputShape(2);
-
- const TensorShape expanded_filter_shape =
- depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
- : filter_shape;
- // Reuse dimension computation logic from conv_grad_ops.cc.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx,
- ConvBackpropComputeDimensionsV2(
- type_string(), num_spatial_dims_, input_shape,
- expanded_filter_shape, out_backprop_shape, dilations_,
- strides_, padding_, data_format_, &dims));
-
- auto filter = ctx->Input(1);
- auto out_backprop = ctx->Input(2);
-
- // The input gradients are computed by a convolution of the output
- // gradients and the filter, with some appropriate padding. See the
- // comment at the top of conv_grad_ops.h for details.
-
- xla::ConvolutionDimensionNumbers dnums;
- dnums.set_input_batch_dimension(batch_dim);
- dnums.set_output_batch_dimension(batch_dim);
- dnums.set_input_feature_dimension(feature_dim);
- dnums.set_output_feature_dimension(feature_dim);
-
- // TF filter shape is [ H, W, ..., inC, outC ]
- // Transpose the input and output features for computing the gradient.
- dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1);
- dnums.set_kernel_output_feature_dimension(num_spatial_dims_);
-
- std::vector<int64> kernel_spatial_dims(num_spatial_dims_);
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
- std::vector<int64> lhs_dilation(num_spatial_dims_);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<int64> ones(num_spatial_dims_, 1);
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_input_spatial_dimensions(dim);
- dnums.add_kernel_spatial_dimensions(i);
- dnums.add_output_spatial_dimensions(dim);
-
- kernel_spatial_dims[i] = i;
- padding[i] = {dims.spatial_dims[i].pad_before,
- dims.spatial_dims[i].pad_after};
- lhs_dilation[i] = dims.spatial_dims[i].stride;
- rhs_dilation[i] = dilations_[dim];
- }
-
- // Mirror the filter in the spatial dimensions.
- xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
-
- // activation gradients
- // = gradients (with padding and dilation) <conv> mirrored_weights
- xla::XlaOp in_backprop = xla::ConvGeneralDilated(
- out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
- lhs_dilation, rhs_dilation, dnums,
- /*feature_group_count=*/
- depthwise_ ? out_backprop_shape.dim_size(feature_dim) /
- filter_shape.dim_size(num_spatial_dims_ + 1)
- : 1);
-
- ctx->SetOutput(0, in_backprop);
+ TensorShape input_tensor_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
+ xla::Shape input_shape =
+ TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
+
+ xla::StatusOr<xla::XlaOp> in_backprop =
+ MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
+ ctx->Input(1), ctx->Input(2), attrs_);
+ OP_REQUIRES_OK(ctx, in_backprop.status());
+ ctx->SetOutput(0, in_backprop.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
@@ -462,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel {
public:
explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
-
- OP_REQUIRES(
- ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1),
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- const TensorShape activations_shape = ctx->InputShape(0);
- TensorShape filter_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
- const TensorShape out_backprop_shape = ctx->InputShape(2);
-
- const TensorShape expanded_filter_shape =
- depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
- : filter_shape;
-
- // Reuse dimension computation logic from conv_grad_ops.cc.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx,
- ConvBackpropComputeDimensionsV2(
- type_string(), num_spatial_dims_, activations_shape,
- expanded_filter_shape, out_backprop_shape, dilations_,
- strides_, padding_, data_format_, &dims));
-
- xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp activations = ctx->Input(0);
- xla::XlaOp gradients = ctx->Input(2);
-
- // The filter gradients are computed by a convolution of the input
- // activations and the output gradients, with some appropriate padding.
- // See the comment at the top of conv_grad_ops.h for details.
-
- xla::ConvolutionDimensionNumbers dnums;
-
- // The activations (inputs) form the LHS of the convolution.
- // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
- // For the gradient computation, we flip the roles of the batch and
- // feature dimensions.
- // Each spatial entry has size in_depth * batch
-
- // Swap n_dim and c_dim in the activations.
- dnums.set_input_batch_dimension(c_dim);
- dnums.set_input_feature_dimension(n_dim);
-
- // The gradients become the RHS of the convolution.
- // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
- // where the batch becomes the input feature for the convolution.
- dnums.set_kernel_input_feature_dimension(n_dim);
- dnums.set_kernel_output_feature_dimension(c_dim);
-
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<int64> window_strides(num_spatial_dims_);
- std::vector<int64> ones(num_spatial_dims_, 1);
-
- // Tensorflow filter shape is [ H, W, ..., inC, outC ].
- for (int i = 0; i < num_spatial_dims_; ++i) {
- dnums.add_output_spatial_dimensions(i);
- }
- dnums.set_output_batch_dimension(num_spatial_dims_);
- dnums.set_output_feature_dimension(num_spatial_dims_ + 1);
-
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_input_spatial_dimensions(dim);
- dnums.add_kernel_spatial_dimensions(dim);
-
- // We will also need to pad the input with zeros such that after the
- // convolution, we get the right size for the filter.
- // The padded_in_rows should be such that when we convolve this with the
- // expanded_out_rows as a filter, we should get filter_rows back.
- //
- const int64 padded_in_size =
- dims.spatial_dims[i].expanded_output_size +
- (dims.spatial_dims[i].filter_size - 1) * dilations_[dim];
-
- // However it can be smaller than input_rows: in this
- // case it means some of the inputs are not used.
- //
- // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
- //
- // INPUT = [ A B C ]
- //
- // FILTER = [ x y ]
- //
- // and the output will only have one column: a = A * x + B * y
- //
- // and input "C" is not used at all.
- //
- // We apply negative padding in this case.
- const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
-
- // + For the VALID padding, we don't pad anything on the top/left side
- // and pad the bottom/right side with the remaining space.
- // + For the SAME padding, we pad top/left side the same as bottom/right
- // side.
- //
- // In addition, if the padded input size is smaller than the input size,
- // we need to ignore some training elements of the input. We do this by
- // applying negative padding on the right/bottom.
- const int64 pad_before =
- padding_ == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
-
- padding[i] = {pad_before, pad_total - pad_before};
- rhs_dilation[i] = dims.spatial_dims[i].stride;
- window_strides[i] = dilations_[dim];
- }
-
- // Besides padding the input, we will also expand output_rows to
- // expanded_out_rows = (output_rows - 1) * stride + 1
- // with zeros in between:
- //
- // a . . . b . . . c . . . d . . . e
- //
- // This is done by specifying the window dilation factors in the
- // convolution HLO below.
- auto filter_backprop =
- xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
- /*lhs_dilation=*/ones, rhs_dilation, dnums);
-
- if (depthwise_) {
- filter_backprop = ContractFilterForDepthwiseBackprop(
- ctx, filter_shape, ctx->input_type(0), filter_backprop, b);
- }
- ctx->SetOutput(0, filter_backprop);
+ TensorShape filter_tensor_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
+ xla::Shape filter_shape =
+ TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
+
+ xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp(
+ ctx->op_kernel().type_string(), ctx->Input(0), filter_shape,
+ ctx->Input(2), attrs_);
+ OP_REQUIRES_OK(ctx, filter_backprop.status());
+ ctx->SetOutput(0, filter_backprop.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp);
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
index 9d1992205b..b589512dcd 100644
--- a/tensorflow/compiler/tf2xla/shape_util.cc
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
// Convert a TensorShape into the equivalent XLA Shape proto.
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
xla::Shape* shape) {
+ xla::PrimitiveType type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
+ *shape = TensorShapeToXLAShape(type, tensor_shape);
+ return Status::OK();
+}
+
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+ const TensorShape& tensor_shape) {
int rank = tensor_shape.dims();
std::vector<int64> dimensions(rank);
std::vector<int64> layout(rank);
@@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
// XLA uses minor-to-major; Tensorflow uses major-to-minor.
std::iota(layout.rbegin(), layout.rend(), 0);
- xla::PrimitiveType type;
- TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
-
- *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
- return Status::OK();
+ return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h
index 58240b9c96..f7e34a5b40 100644
--- a/tensorflow/compiler/tf2xla/shape_util.h
+++ b/tensorflow/compiler/tf2xla/shape_util.h
@@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
xla::Shape* shape);
+// Converts a TensorShape into the equivalent XLA Shape proto with element type
+// `type`. Unlike the DataType overload, this cannot fail (no type conversion).
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+ const TensorShape& tensor_shape);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_