author     2018-09-20 15:42:23 -0700
committer  2018-09-20 15:58:45 -0700
commit     23a88ec5e913ba7086a9aef57875447ccf96e4b5 (patch)
tree       6e94f3f0fa29edb01c2cef29628e9a0d7f0b5034 /tensorflow/compiler/tf2xla
parent     1797aacbd8b910fb8c15577f66257b35af97cc1a (diff)
Now that depthwise convolution is available in XLA, it is more computationally
efficient to represent resize bilinear as a depthwise convolution instead of a
full convolution.
PiperOrigin-RevId: 213896333
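The separable "tent" kernel at the heart of this change is built by Make1DKernel: 2n-1 taps that rise linearly to 1 at the center tap and mirror back down, which is exactly linear interpolation between dilated input samples. As a reference point, the standalone sketch below mirrors the same arithmetic and prints the taps (the Make1DKernelValues wrapper and main are illustrative, not part of this patch):

#include <cstdint>
#include <iostream>
#include <vector>

// Same arithmetic as the patched Make1DKernel, but returning the raw
// values instead of an xla::ConstantR1 so it can run standalone.
std::vector<float> Make1DKernelValues(std::int64_t n) {
  std::vector<float> kernel(n * 2 - 1);
  for (std::int64_t i = 0; i < n; ++i) {
    float v = (i + 1.0f) / n;
    kernel[i] = v;              // rising edge
    kernel[n * 2 - 2 - i] = v;  // mirrored falling edge
  }
  return kernel;
}

int main() {
  // For a resize factor of 4 this prints: 0.25 0.5 0.75 1 0.75 0.5 0.25
  for (float v : Make1DKernelValues(4)) std::cout << v << ' ';
  std::cout << '\n';
}

The patch does not change these values; it changes where they end up. They are now broadcast into a [2k-1, 2k-1, channels, 1] depthwise kernel instead of being masked onto the diagonal of a [2k-1, 2k-1, channels, channels] full kernel.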
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc  76
1 file changed, 34 insertions(+), 42 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d9a0257b70..7b2bb4a7c5 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/numeric.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
@@ -132,14 +133,14 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
 // If the 2D kernel would be very large, the 1D kernel can be applied once in
 // each dimension due to the symmetry of the kernel along all axis to reduce the
 // computational intensity.
-std::vector<float> Make1DKernel(int64 n) {
+xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) {
   std::vector<float> kernel(n * 2 - 1);
   for (int64 i = 0; i < n; ++i) {
     float v = (i + 1.0f) / n;
     kernel[i] = v;
     kernel[n * 2 - 2 - i] = v;
   }
-  return kernel;
+  return xla::ConstantR1<float>(builder, kernel);
 }
 
 // Kernels with more than 16 spatial elements are considered intense and the
@@ -149,41 +150,26 @@ const int64 kMax2DKernelSize = 16;
 xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
                                     absl::Span<const int64> kernel_size,
                                     int64 channels) {
-  xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
+  auto depthwise_kernel = xla::Broadcast(
+      xla::Zero(builder, xla::F32),
+      {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1});
 
-  auto diag = xla::ConvertElementType(
-      xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
-                                             2 * kernel_size[1] - 1, channels}),
-              channels_iota, /*broadcast_dimensions=*/{2}),
-      xla::PrimitiveType::F32);
   return xla::Mul(
-      xla::Mul(diag,
-               xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+      xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]),
               /*broadcast_dimensions=*/{1}),
-      xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+      Make1DKernel(builder, kernel_size[0]),
       /*broadcast_dimensions=*/{0});
 }
 
 xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
                                          absl::Span<const int64> kernel_size,
                                          int64 channels, int64 dim) {
-  xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
-
-  auto diag = xla::ConvertElementType(
-      xla::Eq(
-          xla::Broadcast(channels_iota,
-                         {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
-                          dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
-          channels_iota, /*broadcast_dimensions=*/{2}),
-      xla::PrimitiveType::F32);
-  if (dim == 1) {
-    return xla::Mul(
-        diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
-        /*broadcast_dimensions=*/{1});
-  }
-  return xla::Mul(diag,
-                  xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
-                  /*broadcast_dimensions=*/{0});
+  auto depthwise_kernel =
+      xla::Broadcast(xla::Zero(builder, xla::F32),
+                     {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+                      dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1});
+  return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]),
+                  /*broadcast_dimensions=*/{dim});
 }
 
 xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
@@ -206,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
   xla::ConvolutionDimensionNumbers dimension_numbers;
   dimension_numbers.set_input_batch_dimension(0);
   dimension_numbers.set_output_batch_dimension(0);
-  dimension_numbers.set_input_feature_dimension(3);
-  dimension_numbers.set_output_feature_dimension(3);
+  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
   for (int i = 0; i < num_spatial_dims; ++i) {
     dimension_numbers.add_input_spatial_dimensions(1 + i);
     dimension_numbers.add_output_spatial_dimensions(1 + i);
@@ -285,7 +271,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
         {{dims.kernel_size[0] - 1, upper_padding[0]},
          {dims.kernel_size[1] - 1, upper_padding[1]}},
         /*lhs_dilation=*/dims.kernel_size,
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
   } else {
     xla::XlaOp kernel0 =
         MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -294,7 +281,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
         /*padding=*/
         {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
         /*lhs_dilation=*/{dims.kernel_size[0], 1},
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
     xla::XlaOp kernel1 =
         MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
     output = xla::ConvGeneralDilated(
@@ -302,7 +290,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
         /*padding=*/
         {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
         /*lhs_dilation=*/{1, dims.kernel_size[1]},
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
   }
 
   // Add broadcasts to handle expanding from a size == 1 dimension to a
@@ -331,15 +320,15 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
   xla::ConvolutionDimensionNumbers dimension_numbers;
   dimension_numbers.set_input_batch_dimension(0);
   dimension_numbers.set_output_batch_dimension(0);
-  dimension_numbers.set_input_feature_dimension(3);
-  dimension_numbers.set_output_feature_dimension(3);
+  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
   for (int i = 0; i < num_spatial_dims; ++i) {
-    dimension_numbers.add_input_spatial_dimensions(1 + i);
-    dimension_numbers.add_output_spatial_dimensions(1 + i);
+    dimension_numbers.add_input_spatial_dimensions(i + 1);
+    dimension_numbers.add_output_spatial_dimensions(i + 1);
     dimension_numbers.add_kernel_spatial_dimensions(i);
   }
-  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
-  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
+  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
   xla::XlaOp output;
   if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
     xla::XlaOp kernel =
@@ -362,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
          {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
         /*lhs_dilation=*/dims.stride,
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
   } else {
     xla::XlaOp kernel0 =
         MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -388,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
         /*padding=*/
         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
         /*lhs_dilation=*/{dims.stride[0], 1},
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
 
     output = xla::ConvGeneralDilated(
         output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
         /*padding=*/
         {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
         /*lhs_dilation=*/{1, dims.stride[1]},
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
   }
 
   // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
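What the change amounts to: the old code emitted a full convolution whose [2k0-1, 2k1-1, channels, channels] kernel was zero everywhere except on the channel diagonal (the Iota/Eq mask), so almost all of its multiply-adds ran against zeros. The new code builds a [2k0-1, 2k1-1, channels, 1] kernel and passes /*feature_group_count=*/channels to xla::ConvGeneralDilated, making per-channel independence structural instead of numerical. The toy model below shows one 1-D pass of the same scheme in plain C++ (not XLA; the DepthwiseConv1D helper and the pre-dilated inputs are illustrative assumptions): the input is dilated with zeros, as lhs_dilation does, and each channel is convolved with the shared tent kernel with no cross-channel arithmetic.

#include <cstddef>
#include <iostream>
#include <vector>

// Naive depthwise 1-D convolution: out[c][x] = sum_k in[c][x + k] * w[k].
// The same weights are shared by every channel, matching the broadcast of
// Make1DKernel over the channel dimension in the patch.
std::vector<std::vector<float>> DepthwiseConv1D(
    const std::vector<std::vector<float>>& in, const std::vector<float>& w) {
  const std::size_t out_len = in[0].size() - w.size() + 1;
  std::vector<std::vector<float>> out(in.size(),
                                      std::vector<float>(out_len, 0.0f));
  for (std::size_t c = 0; c < in.size(); ++c)     // each channel on its own:
    for (std::size_t x = 0; x < out_len; ++x)     // no sum over channels,
      for (std::size_t k = 0; k < w.size(); ++k)  // i.e. a depthwise conv
        out[c][x] += in[c][x + k] * w[k];
  return out;
}

int main() {
  // Two channels, {1,3,5} and {2,4,6}, pre-dilated with zeros the way
  // lhs_dilation dilates the input inside ConvGeneralDilated.
  const std::vector<std::vector<float>> dilated = {{1, 0, 3, 0, 5},
                                                   {2, 0, 4, 0, 6}};
  // Make1DKernel(2) taps: the tent {0.5, 1, 0.5}.
  for (const auto& channel : DepthwiseConv1D(dilated, {0.5f, 1.0f, 0.5f})) {
    for (float v : channel) std::cout << v << ' ';  // prints: 2 3 4 / 3 4 5
    std::cout << '\n';
  }
}

Each channel comes out as the interior samples of its own 2x bilinear upsample ({1,3,5} -> 2 3 4, {2,4,6} -> 3 4 5; the edge padding the real op applies is omitted for brevity). Channels never mix, which is why the diagonally masked full kernel was pure overhead once XLA could express the grouping directly.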