path: root/tensorflow/compiler/tf2xla
author    Blake Hechtman <blakehechtman@google.com>  2018-09-20 15:42:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-20 15:58:45 -0700
commit    23a88ec5e913ba7086a9aef57875447ccf96e4b5 (patch)
tree      6e94f3f0fa29edb01c2cef29628e9a0d7f0b5034 /tensorflow/compiler/tf2xla
parent    1797aacbd8b910fb8c15577f66257b35af97cc1a (diff)
Represent resize bilinear as a depthwise convolution instead of a full
convolution; this is more computationally efficient now that depthwise
convolution (feature_group_count) exists in XLA.

PiperOrigin-RevId: 213896333
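For intuition: the 1-D bilinear kernel built by Make1DKernel in the diff below is a triangular "tent" filter that is identical for every channel, which is what makes a per-channel (depthwise) convolution sufficient. A minimal standalone sketch of that arithmetic (plain C++, no XLA dependency; the name TentKernel is ours):

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the arithmetic of Make1DKernel below: a tent filter of length
// 2n-1 that ramps 1/n, 2/n, ..., 1, ..., 2/n, 1/n.
std::vector<float> TentKernel(int64_t n) {
  std::vector<float> kernel(n * 2 - 1);
  for (int64_t i = 0; i < n; ++i) {
    float v = (i + 1.0f) / n;
    kernel[i] = v;              // left half, ramping up
    kernel[n * 2 - 2 - i] = v;  // right half, mirrored
  }
  return kernel;
}

int main() {
  for (float v : TentKernel(3)) std::printf("%g ", v);
  // prints: 0.333333 0.666667 1 0.666667 0.333333
}
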
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 76
1 file changed, 34 insertions(+), 42 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d9a0257b70..7b2bb4a7c5 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -132,14 +133,14 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
// If the 2D kernel would be very large, the 1D kernel can be applied once in
// each dimension, due to the symmetry of the kernel along all axes, to reduce
// the computational intensity.
-std::vector<float> Make1DKernel(int64 n) {
+xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) {
std::vector<float> kernel(n * 2 - 1);
for (int64 i = 0; i < n; ++i) {
float v = (i + 1.0f) / n;
kernel[i] = v;
kernel[n * 2 - 2 - i] = v;
}
- return kernel;
+ return xla::ConstantR1<float>(builder, kernel);
}
// Kernels with more than 16 spatial elements are considered intense and the
@@ -149,41 +150,26 @@ const int64 kMax2DKernelSize = 16;
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
+ auto depthwise_kernel = xla::Broadcast(
+ xla::Zero(builder, xla::F32),
+ {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1});
- auto diag = xla::ConvertElementType(
- xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
- 2 * kernel_size[1] - 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
return xla::Mul(
- xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+ xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]),
/*broadcast_dimensions=*/{1}),
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+ Make1DKernel(builder, kernel_size[0]),
/*broadcast_dimensions=*/{0});
}
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels, int64 dim) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
-
- auto diag = xla::ConvertElementType(
- xla::Eq(
- xla::Broadcast(channels_iota,
- {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
- dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
- if (dim == 1) {
- return xla::Mul(
- diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
- /*broadcast_dimensions=*/{1});
- }
- return xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
- /*broadcast_dimensions=*/{0});
+ auto depthwise_kernel =
+ xla::Broadcast(xla::Zero(builder, xla::F32),
+ {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+ dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1});
+ return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]),
+ /*broadcast_dimensions=*/{dim});
}
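
The two hunks above replace the Iota/Eq construction, which materialized a full [2*k0-1, 2*k1-1, channels, channels] kernel that was zero everywhere off the channel diagonal, with a [2*k0-1, 2*k1-1, channels, 1] depthwise kernel: zeros are broadcast to the target shape and the 1-D tent kernels are added in via broadcast_dimensions, replicating the same separable filter across all channels. A host-side sketch of the resulting spatial values (illustrating the math only, not the XLA ops; TentKernel is from the earlier sketch):

#include <cstdint>
#include <vector>

std::vector<float> TentKernel(int64_t n);  // from the earlier sketch

// The Mul with broadcast_dimensions {1} and then {0} above amounts to the
// outer product of two 1-D tent kernels; every channel carries this same
// 2-D separable filter.
std::vector<std::vector<float>> Separable2DKernel(int64_t k0, int64_t k1) {
  std::vector<float> rows = TentKernel(k0);  // spatial dim 0
  std::vector<float> cols = TentKernel(k1);  // spatial dim 1
  std::vector<std::vector<float>> kernel(
      rows.size(), std::vector<float>(cols.size()));
  for (size_t i = 0; i < rows.size(); ++i)
    for (size_t j = 0; j < cols.size(); ++j)
      kernel[i][j] = rows[i] * cols[j];
  return kernel;
}
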
xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
@@ -206,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
dimension_numbers.add_input_spatial_dimensions(1 + i);
dimension_numbers.add_output_spatial_dimensions(1 + i);
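
Replacing the literal 3 with num_spatial_dims + 1 is behavior-preserving for this 2-D op (num_spatial_dims == 2, so the feature dimension is still 3, i.e. NHWC); it just derives the layout from one variable instead of hard-coding it. A trivial check of the mapping (plain C++):

#include <cassert>

int main() {
  const int num_spatial_dims = 2;                // 2-D image resize
  const int batch_dim = 0;                       // N
  const int feature_dim = num_spatial_dims + 1;  // C
  assert(feature_dim == 3);                      // same NHWC layout as before
  // spatial dims occupy 1..num_spatial_dims (H at 1, W at 2)
}
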
@@ -285,7 +271,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
{{dims.kernel_size[0] - 1, upper_padding[0]},
{dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -294,7 +281,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*padding=*/
{{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
xla::XlaOp kernel1 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
output = xla::ConvGeneralDilated(
@@ -302,7 +290,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
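
feature_group_count is what turns these calls into depthwise convolutions: the input's channels are split into `channels` groups of one, and each group is convolved with its own slice of the kernel. A minimal plain-C++ model of a 1-D depthwise convolution (valid padding, stride 1, input at least as long as the kernel; names are ours, not XLA's):

#include <vector>

// in[c][x]: one row per channel; kernel[c][k]: one filter per channel.
// Each output channel depends only on its own input channel, which is the
// feature_group_count == channels case of xla::ConvGeneralDilated.
std::vector<std::vector<float>> DepthwiseConv1D(
    const std::vector<std::vector<float>>& in,
    const std::vector<std::vector<float>>& kernel) {
  const size_t channels = in.size();
  const size_t k = kernel[0].size();
  std::vector<std::vector<float>> out(channels);
  for (size_t c = 0; c < channels; ++c) {
    out[c].assign(in[c].size() - k + 1, 0.0f);
    for (size_t x = 0; x < out[c].size(); ++x)
      for (size_t i = 0; i < k; ++i)
        out[c][x] += in[c][x + i] * kernel[c][i];
  }
  return out;
}
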
// Add broadcasts to handle expanding from a size == 1 dimension to a
@@ -331,15 +320,15 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
- dimension_numbers.add_input_spatial_dimensions(1 + i);
- dimension_numbers.add_output_spatial_dimensions(1 + i);
+ dimension_numbers.add_input_spatial_dimensions(i + 1);
+ dimension_numbers.add_output_spatial_dimensions(i + 1);
dimension_numbers.add_kernel_spatial_dimensions(i);
}
- dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
- dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
xla::XlaOp output;
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
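
The kernel feature-dimension swap matches the new kernel shape: MakeBilinearResizeKernel now returns [2*k0-1, 2*k1-1, channels, 1], so the channel axis (index num_spatial_dims) is the kernel's output-feature dimension and the trailing size-1 axis (index num_spatial_dims + 1) is its input-feature dimension, consistent with one input channel per group. A sketch of the grouped-convolution shape bookkeeping this has to satisfy (our own helper, stating what we believe XLA's grouped-convolution shape rule to be; treat it as an assumption):

#include <cstdint>

// For a grouped convolution with G = feature_group_count:
//   kernel input-feature size  == input channels / G
//   kernel output-feature size == total output channels (divisible by G)
// With kernel [2k0-1, 2k1-1, channels, 1] and G == channels, both hold:
// input-feature size 1 == channels / channels; output-feature size == channels.
bool DepthwiseKernelShapeOk(int64_t input_channels, int64_t kernel_in_features,
                            int64_t kernel_out_features, int64_t groups) {
  return kernel_in_features == input_channels / groups &&
         kernel_out_features % groups == 0;
}
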
@@ -362,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.stride,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -388,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
/*lhs_dilation=*/{dims.stride[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
output = xla::ConvGeneralDilated(
output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/{1, dims.stride[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
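
For a back-of-the-envelope view of the efficiency claim in the commit message: a full convolution mixes every pair of channels, while a depthwise convolution touches each channel once, so the multiply-add count drops by a factor of `channels`. A small cost comparison (plain C++; a rough model that ignores padding and dilation):

#include <cstdint>
#include <cstdio>

int64_t FullConvMacs(int64_t h, int64_t w, int64_t c, int64_t kh, int64_t kw) {
  return h * w * kh * kw * c * c;  // every output channel reads every input channel
}

int64_t DepthwiseConvMacs(int64_t h, int64_t w, int64_t c, int64_t kh, int64_t kw) {
  return h * w * kh * kw * c;      // each output channel reads one input channel
}

int main() {
  // e.g. a 64x64 image with 256 channels and a 3x3 resize kernel:
  std::printf("full: %lld, depthwise: %lld\n",
              (long long)FullConvMacs(64, 64, 256, 3, 3),
              (long long)DepthwiseConvMacs(64, 64, 256, 3, 3));
  // Depthwise is 256x cheaper; the old diagonal kernel multiplied by zero
  // for every off-diagonal channel pair anyway.
}
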