aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Blake Hechtman <blakehechtman@google.com>2017-11-30 11:13:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-30 11:17:10 -0800
commitea1c29552b01f3404e27999a27a1919b3accc594 (patch)
treef943b90431dc0710f9358760169470e56f98b09d
parent4e8301be75a234d53b08bec577ac0069fc40bea3 (diff)
Change depthwise convolution filter expansion and contraction with
algebraic manipulation instead of slices and pads that are more difficult to fuse. PiperOrigin-RevId: 177480353
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc166
1 files changed, 112 insertions, 54 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index c150394c07..61f4d1993a 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -46,72 +46,130 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution(
return expanded_shape;
}
+// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
+xla::ComputationDataHandle CreateExpandedZero(
+ const TensorShape& filter_shape, DataType dtype,
+ xla::ComputationBuilder* builder) {
+ TensorShape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ return builder->Broadcast(XlaHelpers::Zero(builder, dtype),
+ expanded_filter_shape.dim_sizes());
+}
+
+// Create a mask for depthwise convolution that will make a normal convolution
+// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
+// depthwise filter this returns a [2, 2, 3, 6] tensor
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// The first step is to create an iota tensor, A, that is [3]
+// 0 1 2
+//
+// and another tensor, B, that is [3 * 2]
+// 0 1 2 3 4 5
+//
+// and divide B by 2 to get
+// 0 0 1 1 2 2
+//
+// then we broadcast B to [2, 2, 3, 3 * 2]
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// Finally compare A and broadcasted B in dimension 2 and return the result at
+// the beginning of the comment.
+xla::ComputationDataHandle CreateExpandedFilterMask(
+ const TensorShape& filter_shape, xla::ComputationBuilder* builder) {
+ TensorShape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
+ int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
+
+  // Create an M sized linspace and an M*N sized linspace that will be
+ // broadcasted into perpendicular dimensions and compared.
+ xla::ComputationDataHandle input_feature_iota;
+ // DT_INT32 Iota will always return status::OK().
+ TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature,
+ &input_feature_iota));
+ xla::ComputationDataHandle expanded_feature_iota;
+ TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
+ input_feature * depthwise_multiplier,
+ &expanded_feature_iota));
+
+ // Divide the M*N sized linspace by the depthwise_multiplier to create
+ // [0 0 1 1 2 2] in the example in the function comment.
+ expanded_feature_iota =
+ builder->Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
+
+ // Broadcast the N*M linspace to [H, W, ..., M, M*N].
+ auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
+ expanded_feature_broadcast_dims.pop_back();
+ auto broadcasted_expanded_feature_iota = builder->Broadcast(
+ expanded_feature_iota, expanded_feature_broadcast_dims);
+
+ // Compare the broadcasted linspace to the input feature linspace in the
+ // input feature dimension to create a diagonal predicate.
+ return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dims() - 2});
+}
+
// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
// zeros for the cross-depth filters. Used to build a depthwise convolution.
xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution(
const TensorShape& filter_shape, DataType dtype,
const xla::ComputationDataHandle& filter,
xla::ComputationBuilder* builder) {
- // Filter has shape [H, W, ..., M, N]
- // Dilate to [H, W, ..., M*M, N] using M inter-element padding, and then
- // reshape to [H, W, ..., M, M*N].
- int num_spatial_dims = filter_shape.dims() - 2;
- const int64 in_depth = filter_shape.dim_size(num_spatial_dims);
- xla::PaddingConfig padding = xla::MakeNoPaddingConfig(filter_shape.dims());
- padding.mutable_dimensions(num_spatial_dims)->set_interior_padding(in_depth);
- auto dilated_filter =
- builder->Pad(filter, XlaHelpers::Zero(builder, dtype), padding);
-
+ int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
+ int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- return builder->Reshape(dilated_filter, expanded_filter_shape.dim_sizes());
+
+ // Create a [H, W, ..., 1, N*M] reshape of the filter.
+ TensorShape implicit_broadcast_filter_shape = expanded_filter_shape;
+ implicit_broadcast_filter_shape.set_dim(
+ implicit_broadcast_filter_shape.dims() - 2, 1);
+ implicit_broadcast_filter_shape.set_dim(
+ implicit_broadcast_filter_shape.dims() - 1,
+ depthwise_multiplier * input_feature);
+ auto implicit_broadcast_filter =
+ builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
+
+ // Broadcast the filter to [H, W, ..., M, M*N].
+ auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder);
+ auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero);
+
+  // If the filter mask is set, choose the broadcasted filter, otherwise,
+ // choose zero.
+ return builder->Select(CreateExpandedFilterMask(filter_shape, builder),
+ expanded_filter, expanded_zero);
}
// Inverse of ExpandFilterForDepthwiseConvolution.
xla::ComputationDataHandle ContractFilterForDepthwiseBackprop(
- const TensorShape& filter_shape, DataType dtype,
+ XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype,
const xla::ComputationDataHandle& filter_backprop,
xla::ComputationBuilder* builder) {
- int num_spatial_dims = filter_shape.dims() - 2;
-
- // Reshape to [H, W, ..., M*M, N]
- TensorShape shape = filter_shape;
- int64 in_depth = filter_shape.dim_size(num_spatial_dims);
- shape.set_dim(num_spatial_dims, in_depth * in_depth);
- auto reshaped = builder->Reshape(filter_backprop, shape.dim_sizes());
-
- std::vector<int64> zeros(filter_shape.dims());
- std::vector<int64> strides(filter_shape.dims(), 1LL);
- strides[num_spatial_dims] = in_depth + 1;
- return builder->Slice(reshaped, zeros, shape.dim_sizes(), strides);
-
- // Alternate implementation for backends without strided Slice() support.
- // TODO(phawkins): Remove when all backends support strided slice.
- // // Pad [..., M * (M + 1), N]
- // xla::PaddingConfig config =
- // xla::MakeNoPaddingConfig(filter_shape.dims());
- // config.mutable_dimensions(num_spatial_dims)
- // ->set_edge_padding_high(in_depth);
- // auto zero = XlaHelpers::Zero(builder, dtype);
- // auto padded = builder->Pad(reshaped, zero, config);
- //
- // // Reshape to [..., M, M + 1, N]
- // shape = filter_shape;
- // shape.set_dim(num_spatial_dims, in_depth);
- // shape.set_dim(num_spatial_dims + 1, in_depth + 1);
- // int64 out_depth = filter_shape.dim_size(num_spatial_dims + 1);
- // shape.AddDim(out_depth);
- // reshaped = builder->Reshape(padded, shape.dim_sizes());
- //
- // // Slice to [..., M, 1, N]
- // std::vector<int64> zeros(shape.dims());
- // std::vector<int64> strides(shape.dims(), 1LL);
- // shape.set_dim(num_spatial_dims + 1, 1);
- // auto sliced = builder->Slice(reshaped, zeros, shape.dim_sizes(),
- // strides);
- //
- // // Reshape to [..., M, N]
- // return builder->Reshape(sliced, filter_shape.dim_sizes());
+ TensorShape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ auto masked_expanded_filter = builder->Select(
+ CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
+ CreateExpandedZero(filter_shape, dtype, builder));
+ return builder->Reshape(
+ builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
+ *ctx->GetOrCreateAdd(dtype),
+ {expanded_filter_shape.dims() - 2}),
+ filter_shape.dim_sizes());
}
class ConvOp : public XlaOpKernel {
@@ -202,7 +260,7 @@ class ConvOp : public XlaOpKernel {
dims.set_input_feature_dimension(feature_dim);
dims.set_output_feature_dimension(feature_dim);
for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
+ const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
dims.add_input_spatial_dimensions(dim);
dims.add_kernel_spatial_dimensions(i);
dims.add_output_spatial_dimensions(dim);
@@ -574,7 +632,7 @@ class ConvBackpropFilterOp : public XlaOpKernel {
if (depthwise_) {
filter_backprop_reshaped = ContractFilterForDepthwiseBackprop(
- filter_shape, ctx->input_type(0), filter_backprop_reshaped, b);
+ ctx, filter_shape, ctx->input_type(0), filter_backprop_reshaped, b);
}
ctx->SetOutput(0, filter_backprop_reshaped);
}