diff options
author | Justin Lebar <jlebar@google.com> | 2018-09-17 23:09:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 23:13:50 -0700 |
commit | 9cc7bbe5b476bec556d7dce235996a03775d7492 (patch) | |
tree | 7943f0d1eb95737bd2b3792facf1a39ec3e7d370 /tensorflow/compiler/tf2xla | |
parent | 7c826588b058c14fd8c152bedb4e256c57ae1248 (diff) |
[XLA] Refactor conv_ops emitters to make them reusable.
PiperOrigin-RevId: 213398930
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 22 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc | 509 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h | 69 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/conv_ops.cc | 551 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/shape_util.cc | 14 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/shape_util.h | 5 |
6 files changed, 661 insertions, 509 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 46794f7b50..3e823254d3 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -113,6 +113,7 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":conv_op_helpers", ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", @@ -172,6 +173,27 @@ tf_kernel_library( ], ) +cc_library( + name = "conv_op_helpers", + srcs = ["conv_op_helpers.cc"], + hdrs = ["conv_op_helpers.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/kernels:conv_ops", + "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/types:span", + ], +) + tf_kernel_library( name = "while_op", srcs = ["while_op.cc"], diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc new file mode 100644 index 0000000000..c9a1be4940 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -0,0 +1,509 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for 2D convolution. + +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +// Returns the expanded size of a filter used for depthwise convolution. +// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. +xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { + int num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); // Crash OK + xla::Shape expanded_shape = shape; + expanded_shape.set_dimensions( + num_dims - 1, + shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1)); + return expanded_shape; +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tensor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape, + xla::XlaBuilder* builder) { + xla::Shape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = + filter_shape.dimensions(filter_shape.dimensions_size() - 1); + int64 input_feature = + filter_shape.dimensions(filter_shape.dimensions_size() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + std::vector<int64> expanded_feature_broadcast_dims( + expanded_filter_shape.dimensions().begin(), + expanded_filter_shape.dimensions().end()); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dimensions_size() - 2}); +} + +// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to +// build a depthwise convolution. +xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, + const xla::XlaOp& filter) { + int64 input_feature_dim = filter_shape.dimensions_size() - 2; + int64 output_feature_dim = filter_shape.dimensions_size() - 1; + int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim); + int64 input_feature = filter_shape.dimensions(input_feature_dim); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + xla::Shape implicit_broadcast_filter_shape = filter_shape; + implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1); + implicit_broadcast_filter_shape.set_dimensions( + output_feature_dim, depthwise_multiplier * input_feature); + return xla::Reshape( + filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions())); +} + +// Reduces the results of the convolution with an expanded filter to the +// non-expanded filter. +xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape, + const xla::XlaOp& filter_backprop, + xla::XlaBuilder* builder) { + auto masked_expanded_filter = + xla::Select(CreateExpandedFilterMask(filter_shape, builder), + filter_backprop, xla::ZerosLike(filter_backprop)); + + auto elem_type = filter_shape.element_type(); + return xla::Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the select above guarantees + // that only one element is non zero, so there cannot be accumulated + // precision error. + xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type), + CreateScalarAddComputation(elem_type, builder), + {filter_shape.dimensions_size() - 2}), + xla::AsInt64Slice(filter_shape.dimensions())); +} + +// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA +// convolutions (as currently implemented). +Status CheckConvAttrs(const ConvOpAttrs& attrs) { + const int num_dims = attrs.num_spatial_dims + 2; + if (attrs.strides.size() != num_dims) { + return errors::InvalidArgument("Sliding window strides field must specify ", + num_dims, " dimensions"); + } + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not yet support strides in the batch and " + "depth dimensions."); + } + if (attrs.dilations.size() != num_dims) { + return errors::InvalidArgument("Dilations field must specify ", num_dims, + " dimensions"); + } + if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not support dilations in the batch and " + "depth dimensions."); + } + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + if (attrs.dilations[input_dim] < 1) { + return errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + attrs.dilations[input_dim]); + } + } + return Status::OK(); +} + +// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes +// to TensorShapes. +Status ConvBackpropComputeDimensionsV2XlaShapes( + StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, + const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, + absl::Span<const int32> dilations, const std::vector<int32>& strides, + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { + TensorShape input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); + return ConvBackpropComputeDimensionsV2( + label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape, dilations, strides, padding, data_format, + dims); +} + +} // anonymous namespace + +xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims, + bool depthwise, + OpKernelConstruction* ctx) { + ConvOpAttrs attrs; + attrs.num_spatial_dims = num_spatial_dims; + attrs.depthwise = depthwise; + TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); + TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); + TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + + string data_format; + TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); + if (!FormatFromString(data_format, &attrs.data_format)) { + return errors::InvalidArgument("Invalid data format: ", data_format); + } + + return attrs; +} + +xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = conv_input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input)); + // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth] + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + + // For 2D convolution, there should be 4 dimensions. + int num_dims = attrs.num_spatial_dims + 2; + if (input_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument("input must be ", num_dims, "-dimensional", + input_shape.DebugString()); + } + if (filter_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument( + "filter must be ", num_dims, + "-dimensional: ", filter_shape.DebugString()); + } + + // The last two dimensions of the filter are the input and output shapes. + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims); + // The 'C' dimension for input is in_depth. It must be the same as + // the filter's in_depth. + if (in_depth != input_shape.dimensions(feature_dim)) { + return errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, " vs ", + input_shape.dimensions(feature_dim)); + } + + if (attrs.depthwise) { + filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); + } + + xla::ConvolutionDimensionNumbers dims; + std::vector<int64> window_strides(attrs.num_spatial_dims); + std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1); + std::vector<int64> rhs_dilation(attrs.num_spatial_dims); + std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims); + + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims); + dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dims.add_input_spatial_dimensions(dim); + dims.add_kernel_spatial_dimensions(i); + dims.add_output_spatial_dimensions(dim); + window_strides[i] = attrs.strides.at(dim); + rhs_dilation[i] = attrs.dilations.at(dim); + + int64 unused_output_size; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + input_shape.dimensions(dim), filter_shape.dimensions(i), + rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size, + &padding[i].first, &padding[i].second)); + } + + return xla::ConvGeneralDilated( + conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, + dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1); +} + +xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + int num_dims = attrs.num_spatial_dims + 2; + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + auto* builder = filter.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(out_backprop)); + + xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape, + out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, + attrs.data_format, &dims)); + + // The input gradients are computed by a convolution of the output + // gradients and the filter, with some appropriate padding. See the + // comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); + + // TF filter shape is [ H, W, ..., inC, outC ] + // Transpose the input and output features for computing the gradient. + dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1); + dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims); + + std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims); + std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims); + std::vector<int64> lhs_dilation(attrs.num_spatial_dims); + std::vector<int64> rhs_dilation(attrs.num_spatial_dims); + std::vector<int64> ones(attrs.num_spatial_dims, 1); + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); + + kernel_spatial_dims[i] = i; + padding[i] = {dims.spatial_dims[i].pad_before, + dims.spatial_dims[i].pad_after}; + lhs_dilation[i] = dims.spatial_dims[i].stride; + rhs_dilation[i] = attrs.dilations[dim]; + } + + // Mirror the filter in the spatial dimensions. + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); + + // activation gradients + // = gradients (with padding and dilation) <conv> mirrored_weights + return xla::ConvGeneralDilated( + out_backprop, mirrored_weights, /*window_strides=*/ones, padding, + lhs_dilation, rhs_dilation, dnums, + /*feature_group_count=*/ + attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) / + filter_shape.dimensions(attrs.num_spatial_dims + 1) + : 1); +} + +xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = activations.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape activations_shape, + builder->GetShape(activations)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(gradients)); + const xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + + // The last two dimensions of the filter are the input and output shapes. + int num_dims = attrs.num_spatial_dims + 2; + int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + // Swap n_dim and c_dim in the activations. + dnums.set_input_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] + // where the batch becomes the input feature for the convolution. + dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims); + std::vector<int64> rhs_dilation(attrs.num_spatial_dims); + std::vector<int64> window_strides(attrs.num_spatial_dims); + std::vector<int64> ones(attrs.num_spatial_dims, 1); + + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(attrs.num_spatial_dims); + dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(dim); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + // + const int64 padded_in_size = + dims.spatial_dims[i].expanded_output_size + + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; + + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int64 pad_before = + attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0; + + padding[i] = {pad_before, pad_total - pad_before}; + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = attrs.dilations[dim]; + } + + // Besides padding the input, we will also expand output_rows to + // expanded_out_rows = (output_rows - 1) * stride + 1 + // with zeros in between: + // + // a . . . b . . . c . . . d . . . e + // + // This is done by specifying the window dilation factors in the + // convolution HLO below. + auto filter_backprop = + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); + + if (attrs.depthwise) { + filter_backprop = ContractFilterForDepthwiseBackprop( + filter_shape, filter_backprop, activations.builder()); + } + + return filter_backprop; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h new file mode 100644 index 0000000000..6e1b70a478 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ + +#include <vector> + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +// This header exposes utilities for translating TensorFlow convolution ops into +// XLA ops. +// +// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g. +// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in +// this header to implement a new and exciting convolution op, for example a +// fused TensorFlow op that contains a convolution and other things. + +namespace tensorflow { + +// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA +// convolution. +struct ConvOpAttrs { + // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`. + static xla::StatusOr<ConvOpAttrs> Create(int num_spatial_dims, bool depthwise, + OpKernelConstruction* ctx); + + bool depthwise; + int num_spatial_dims; + std::vector<int32> dilations; + std::vector<int32> strides; + Padding padding; + TensorFormat data_format; +}; + +// Creates a new XLA forward or backward convolution with the given inputs and +// attributes. +xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs); +xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs); +xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 674720e22f..cd7c820be0 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -15,12 +15,17 @@ limitations under the License. // XLA-specific Ops for 2D convolution. +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -33,250 +38,28 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { - namespace { -// Returns the expanded size of a filter used for depthwise convolution. -// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. -TensorShape ExpandedFilterShapeForDepthwiseConvolution( - const TensorShape& shape) { - int num_dims = shape.dims(); - CHECK_GE(num_dims, 2); - TensorShape expanded_shape = shape; - expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) * - shape.dim_size(num_dims - 1)); - return expanded_shape; -} - -// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. -xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return xla::Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); -} - -// Create a mask for depthwise convolution that will make a normal convolution -// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tensor -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// The first step is to create a one tensor, A, that is [3] -// 0 1 2 -// -// and another tensor, B, that is [3 * 2] -// 0 1 2 3 4 5 -// -// and divide B it by 2 to get -// 0 0 1 1 2 2 -// -// then we broadcast the B to [2, 2, 3, 3 * 2] -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// Finally compare A and broadcasted B in dimension 2 amd return the result at -// the beginning of the comment. -xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); - int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); - - // Create a M sized linspace and an M*N sized linspace that will be - // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); - xla::XlaOp expanded_feature_iota = - xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); - - // Divide the M*N sized linspace by the depthwise_multiplier to create - // [0 0 1 1 2 2] in the example in the function comment. - expanded_feature_iota = - xla::Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); - - // Broadcast the N*M linspace to [H, W, ..., M, M*N]. - auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); - expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = - xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); - - // Compare the broadcasted linspace to the input feature linspace in the - // input feature dimension to create a diagonal predicate. - return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); -} - -// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to -// build a depthwise convolution. -xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape, - const xla::XlaOp& filter) { - int64 input_feature_dim = filter_shape.dims() - 2; - int64 output_feature_dim = filter_shape.dims() - 1; - int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim); - int64 input_feature = filter_shape.dim_size(input_feature_dim); - - // Create a [H, W, ..., 1, N*M] reshape of the filter. - TensorShape implicit_broadcast_filter_shape = filter_shape; - implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1); - implicit_broadcast_filter_shape.set_dim(output_feature_dim, - depthwise_multiplier * input_feature); - return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); -} - -// Reduces the results of the convolution with an expanded filter to the -// non-expanded filter. -xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, - const TensorShape& filter_shape, - DataType dtype, - const xla::XlaOp& filter_backprop, - xla::XlaBuilder* builder) { - auto masked_expanded_filter = xla::Select( - CreateExpandedFilterMask(filter_shape, builder), filter_backprop, - CreateExpandedZero(filter_shape, dtype, builder)); - return xla::Reshape( - // This reduce does not need inputs to be converted with - // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with - // ExpandedZero guarantees that only one element is non zero, so there - // cannot be accumulated precision error. - xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}), - filter_shape.dim_sizes()); -} - class ConvOp : public XlaOpKernel { public: explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr<ConvOpAttrs> attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape input_shape = ctx->InputShape(0); - // Input filter is of the following dimensions: - // [ filter_rows, filter_cols, ..., in_depth, out_depth] - const TensorShape filter_shape = ctx->InputShape(1); - - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES( - ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("input must be ", num_dims(), "-dimensional", - input_shape.DebugString())); - OP_REQUIRES( - ctx, filter_shape.dims() == num_dims(), - errors::InvalidArgument("filter must be ", num_dims(), - "-dimensional: ", filter_shape.DebugString())); - - // The last two dimension of the filter are the input and output shapes. - const int64 in_depth = filter_shape.dim_size(num_spatial_dims_); - - // The 'C' dimension for input is in_depth. It must be the same as - // the filter's in_depth. - OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim), - errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, - " vs ", input_shape.dim_size(feature_dim))); - - xla::XlaOp filter = ctx->Input(1); - if (depthwise_) { - filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); - } - - xla::ConvolutionDimensionNumbers dims; - std::vector<int64> window_strides(num_spatial_dims_); - std::vector<int64> lhs_dilation(num_spatial_dims_, 1); - std::vector<int64> rhs_dilation(num_spatial_dims_); - std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); - - dims.set_input_batch_dimension(batch_dim); - dims.set_output_batch_dimension(batch_dim); - dims.set_input_feature_dimension(feature_dim); - dims.set_output_feature_dimension(feature_dim); - dims.set_kernel_input_feature_dimension(num_spatial_dims_); - dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dims.add_input_spatial_dimensions(dim); - dims.add_kernel_spatial_dimensions(i); - dims.add_output_spatial_dimensions(dim); - window_strides[i] = strides_.at(dim); - rhs_dilation[i] = dilations_.at(dim); - - int64 unused_output_size; - OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( - input_shape.dim_size(dim), filter_shape.dim_size(i), - rhs_dilation[i], window_strides[i], padding_, - &unused_output_size, &padding[i].first, &padding[i].second)); - } - - xla::XlaOp conv = xla::ConvGeneralDilated( - ctx->Input(0), filter, window_strides, padding, lhs_dilation, - rhs_dilation, dims, - /*feature_group_count=*/depthwise_ ? in_depth : 1); - ctx->SetOutput(0, conv); + xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_); + OP_REQUIRES_OK(ctx, conv.status()); + ctx->SetOutput(0, conv.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector<int32> dilations_; - std::vector<int32> strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); @@ -308,124 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel { public: explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr<ConvOpAttrs> attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - TensorShape input_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); - - const TensorShape filter_shape = ctx->InputShape(1); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, input_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - auto filter = ctx->Input(1); - auto out_backprop = ctx->Input(2); - - // The input gradients are computed by a convolution of the output - // gradients and the filter, with some appropriate padding. See the - // comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(batch_dim); - dnums.set_output_batch_dimension(batch_dim); - dnums.set_input_feature_dimension(feature_dim); - dnums.set_output_feature_dimension(feature_dim); - - // TF filter shape is [ H, W, ..., inC, outC ] - // Transpose the input and output features for computing the gradient. - dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1); - dnums.set_kernel_output_feature_dimension(num_spatial_dims_); - - std::vector<int64> kernel_spatial_dims(num_spatial_dims_); - std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); - std::vector<int64> lhs_dilation(num_spatial_dims_); - std::vector<int64> rhs_dilation(num_spatial_dims_); - std::vector<int64> ones(num_spatial_dims_, 1); - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(i); - dnums.add_output_spatial_dimensions(dim); - - kernel_spatial_dims[i] = i; - padding[i] = {dims.spatial_dims[i].pad_before, - dims.spatial_dims[i].pad_after}; - lhs_dilation[i] = dims.spatial_dims[i].stride; - rhs_dilation[i] = dilations_[dim]; - } - - // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); - - // activation gradients - // = gradients (with padding and dilation) <conv> mirrored_weights - xla::XlaOp in_backprop = xla::ConvGeneralDilated( - out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, rhs_dilation, dnums, - /*feature_group_count=*/ - depthwise_ ? out_backprop_shape.dim_size(feature_dim) / - filter_shape.dim_size(num_spatial_dims_ + 1) - : 1); - - ctx->SetOutput(0, in_backprop); + TensorShape input_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape)); + xla::Shape input_shape = + TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); + + xla::StatusOr<xla::XlaOp> in_backprop = + MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape, + ctx->Input(1), ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, in_backprop.status()); + ctx->SetOutput(0, in_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector<int32> dilations_; - std::vector<int32> strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp); @@ -462,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel { public: explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr<ConvOpAttrs> attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - - OP_REQUIRES( - ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1), - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape activations_shape = ctx->InputShape(0); - TensorShape filter_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, activations_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp activations = ctx->Input(0); - xla::XlaOp gradients = ctx->Input(2); - - // The filter gradients are computed by a convolution of the input - // activations and the output gradients, with some appropriate padding. - // See the comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - - // The activations (inputs) form the LHS of the convolution. - // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we flip the roles of the batch and - // feature dimensions. - // Each spatial entry has size in_depth * batch - - // Swap n_dim and c_dim in the activations. - dnums.set_input_batch_dimension(c_dim); - dnums.set_input_feature_dimension(n_dim); - - // The gradients become the RHS of the convolution. - // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] - // where the batch becomes the input feature for the convolution. - dnums.set_kernel_input_feature_dimension(n_dim); - dnums.set_kernel_output_feature_dimension(c_dim); - - std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); - std::vector<int64> rhs_dilation(num_spatial_dims_); - std::vector<int64> window_strides(num_spatial_dims_); - std::vector<int64> ones(num_spatial_dims_, 1); - - // Tensorflow filter shape is [ H, W, ..., inC, outC ]. - for (int i = 0; i < num_spatial_dims_; ++i) { - dnums.add_output_spatial_dimensions(i); - } - dnums.set_output_batch_dimension(num_spatial_dims_); - dnums.set_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(dim); - - // We will also need to pad the input with zeros such that after the - // convolution, we get the right size for the filter. - // The padded_in_rows should be such that when we convolve this with the - // expanded_out_rows as a filter, we should get filter_rows back. - // - const int64 padded_in_size = - dims.spatial_dims[i].expanded_output_size + - (dims.spatial_dims[i].filter_size - 1) * dilations_[dim]; - - // However it can be smaller than input_rows: in this - // case it means some of the inputs are not used. - // - // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: - // - // INPUT = [ A B C ] - // - // FILTER = [ x y ] - // - // and the output will only have one column: a = A * x + B * y - // - // and input "C" is not used at all. - // - // We apply negative padding in this case. - const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; - - // + For the VALID padding, we don't pad anything on the top/left side - // and pad the bottom/right side with the remaining space. - // + For the SAME padding, we pad top/left side the same as bottom/right - // side. - // - // In addition, if the padded input size is smaller than the input size, - // we need to ignore some training elements of the input. We do this by - // applying negative padding on the right/bottom. - const int64 pad_before = - padding_ == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0; - - padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = dilations_[dim]; - } - - // Besides padding the input, we will also expand output_rows to - // expanded_out_rows = (output_rows - 1) * stride + 1 - // with zeros in between: - // - // a . . . b . . . c . . . d . . . e - // - // This is done by specifying the window dilation factors in the - // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - - if (depthwise_) { - filter_backprop = ContractFilterForDepthwiseBackprop( - ctx, filter_shape, ctx->input_type(0), filter_backprop, b); - } - ctx->SetOutput(0, filter_backprop); + TensorShape filter_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape)); + xla::Shape filter_shape = + TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape); + + xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), filter_shape, + ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, filter_backprop.status()); + ctx->SetOutput(0, filter_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector<int32> dilations_; - std::vector<int32> strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 9d1992205b..b589512dcd 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, // Convert a TensorShape into the equivalent XLA Shape proto. Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape) { + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); + *shape = TensorShapeToXLAShape(type, tensor_shape); + return Status::OK(); +} + +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape) { int rank = tensor_shape.dims(); std::vector<int64> dimensions(rank); std::vector<int64> layout(rank); @@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - xla::PrimitiveType type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); - - *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); - return Status::OK(); + return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 58240b9c96..f7e34a5b40 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape); +// Converts a TensorShape into the equivalent XLA Shape proto, taking an +// xla::PrimitiveType to specify the element type. This never fails. +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ |