about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/conv_ops.cc 19
-rw-r--r-- tensorflow/compiler/xla/client/computation_builder.cc 45
-rw-r--r-- tensorflow/compiler/xla/client/computation_builder.h 5
-rw-r--r-- tensorflow/compiler/xla/reference_util.cc 11
-rw-r--r-- tensorflow/compiler/xla/reference_util_test.cc 16
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier_test.cc 13
-rw-r--r-- tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc 23
-rw-r--r-- tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc 12
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc 7
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.cc 69
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_folding.cc 34
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc 18
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_thunk.cc 9
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc 6
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/layout_assignment.cc 12
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_insertion.cc 19
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator.cc 16
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_test.cc 15
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc 14
-rw-r--r-- tensorflow/compiler/xla/service/shape_inference.cc 29
-rw-r--r-- tensorflow/compiler/xla/service/shape_inference_test.cc 24
-rw-r--r-- tensorflow/compiler/xla/service/transpose_folding.cc 18
-rw-r--r-- tensorflow/compiler/xla/service/transpose_folding_test.cc 16
-rw-r--r-- tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc 18
-rw-r--r-- tensorflow/compiler/xla/tests/convolution_test.cc 18
-rw-r--r-- tensorflow/compiler/xla/tests/convolution_variants_test.cc 65
-rw-r--r-- tensorflow/compiler/xla/tools/parser/hlo_parser.cc 13
-rw-r--r-- tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc 6
-rw-r--r-- tensorflow/compiler/xla/xla_data.proto 22
30 files changed, 353 insertions, 241 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 885f716afa..c5017704e2 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -184,10 +184,11 @@ class ConvOp : public XlaOpKernel {
dims.set_input_feature_dimension(feature_dim);
dims.set_output_feature_dimension(feature_dim);
for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dims.add_spatial_dimensions(input_dim);
+ int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
+ dims.add_input_spatial_dimensions(dim);
dims.add_kernel_spatial_dimensions(i);
- window_strides.push_back(strides_.at(input_dim));
+ dims.add_output_spatial_dimensions(dim);
+ window_strides.push_back(strides_.at(dim));
}
dims.set_kernel_input_feature_dimension(num_spatial_dims_);
dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1);
@@ -302,9 +303,10 @@ class ConvBackpropInputOp : public XlaOpKernel {
std::vector<int64> lhs_dilation(num_spatial_dims_);
std::vector<int64> ones(num_spatial_dims_, 1);
for (int i = 0; i < num_spatial_dims_; ++i) {
- dnums.add_spatial_dimensions(
- GetTensorSpatialDimIndex(num_dims(), data_format_, i));
+ int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
+ dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(i);
+ dnums.add_output_spatial_dimensions(dim);
kernel_spatial_dims[i] = i;
padding[i] = {dims.spatial_dims[i].pad_before,
@@ -439,9 +441,10 @@ class ConvBackpropFilterOp : public XlaOpKernel {
std::vector<int64> ones(num_spatial_dims_, 1);
for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_spatial_dimensions(dim);
+ int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
+ dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(dim);
+ dnums.add_output_spatial_dimensions(dim);
// We will also need to pad the input with zeros such that after the
// convolution, we get the right size for the filter.
@@ -506,7 +509,7 @@ class ConvBackpropFilterOp : public XlaOpKernel {
std::vector<int64> transpose_dims;
transpose_dims.reserve(num_dims());
for (int i = 0; i < num_spatial_dims_; ++i) {
- transpose_dims.push_back(dnums.spatial_dimensions(i));
+ transpose_dims.push_back(dnums.output_spatial_dimensions(i));
}
transpose_dims.push_back(c_dim);
transpose_dims.push_back(n_dim);
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index b17d221ef5..cce9310003 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -694,11 +694,15 @@ bool ComputationBuilder::VerifyConvolution(
}
return true;
};
- return check_spatial_dimensions("spatial_dimensions",
- dimension_numbers.spatial_dimensions()) &&
+ return check_spatial_dimensions(
+ "input_spatial_dimensions",
+ dimension_numbers.input_spatial_dimensions()) &&
check_spatial_dimensions(
"kernel_spatial_dimensions",
- dimension_numbers.kernel_spatial_dimensions());
+ dimension_numbers.kernel_spatial_dimensions()) &&
+ check_spatial_dimensions(
+ "output_spatial_dimensions",
+ dimension_numbers.output_spatial_dimensions());
}
ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
@@ -730,11 +734,11 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
}
std::vector<int64> base_area_dimensions(
- dimension_numbers.spatial_dimensions_size());
+ dimension_numbers.input_spatial_dimensions_size());
for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
++i) {
base_area_dimensions[i] =
- lhs_shape->dimensions(dimension_numbers.spatial_dimensions(i));
+ lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
}
std::vector<int64> window_dimensions(
@@ -1845,25 +1849,27 @@ ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
dimension_numbers.set_kernel_input_feature_dimension(
kConvKernelInputDimension);
for (int i = 0; i < num_spatial_dims; ++i) {
- dimension_numbers.add_spatial_dimensions(i + 2);
+ dimension_numbers.add_input_spatial_dimensions(i + 2);
dimension_numbers.add_kernel_spatial_dimensions(i + 2);
+ dimension_numbers.add_output_spatial_dimensions(i + 2);
}
return dimension_numbers;
}
/* static */ StatusOr<ConvolutionDimensionNumbers>
ComputationBuilder::CreateConvDimensionNumbers(
- int64 input_batch, int64 input_feature, int64 output_batch,
- int64 output_feature, int64 first_spatial, int64 second_spatial,
+ int64 input_batch, int64 input_feature, int64 input_first_spatial,
+ int64 input_second_spatial, int64 output_batch, int64 output_feature,
+ int64 output_first_spatial, int64 output_second_spatial,
int64 kernel_output_feature, int64 kernel_input_feature,
int64 kernel_first_spatial, int64 kernel_second_spatial) {
- if (std::set<int64>(
- {input_batch, input_feature, first_spatial, second_spatial})
+ if (std::set<int64>({input_batch, input_feature, input_first_spatial,
+ input_second_spatial})
.size() != 4) {
return FailedPrecondition(
"dimension numbers for the input are not unique: (%lld, %lld, %lld, "
"%lld)",
- input_batch, input_feature, first_spatial, second_spatial);
+ input_batch, input_feature, input_first_spatial, input_second_spatial);
}
if (std::set<int64>({kernel_output_feature, kernel_input_feature,
kernel_first_spatial, kernel_second_spatial})
@@ -1874,25 +1880,28 @@ ComputationBuilder::CreateConvDimensionNumbers(
kernel_output_feature, kernel_input_feature, kernel_first_spatial,
kernel_second_spatial);
}
- if (std::set<int64>(
- {output_batch, output_feature, first_spatial, second_spatial})
+ if (std::set<int64>({output_batch, output_feature, output_first_spatial,
+ output_second_spatial})
.size() != 4) {
return FailedPrecondition(
"dimension numbers for the output are not unique: (%lld, %lld, %lld, "
"%lld)",
- output_batch, output_feature, first_spatial, second_spatial);
+ output_batch, output_feature, output_first_spatial,
+ output_second_spatial);
}
ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(input_batch);
dimension_numbers.set_input_feature_dimension(input_feature);
- dimension_numbers.set_output_batch_dimension(output_batch);
- dimension_numbers.set_output_feature_dimension(output_feature);
- dimension_numbers.add_spatial_dimensions(first_spatial);
- dimension_numbers.add_spatial_dimensions(second_spatial);
+ dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
+ dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
+ dimension_numbers.set_output_batch_dimension(output_batch);
+ dimension_numbers.set_output_feature_dimension(output_feature);
+ dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
+ dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
return dimension_numbers;
}
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 3a34010e6a..d2dbbbbebb 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -413,8 +413,9 @@ class ComputationBuilder {
// Creates a ConvolutionDimensionNumbers with the given arguments. Returns an
// error if either the input or the weight dimension numbers have conflicts.
static StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers(
- int64 input_batch, int64 input_feature, int64 output_batch,
- int64 output_feature, int64 first_spatial, int64 second_spatial,
+ int64 input_batch, int64 input_feature, int64 input_first_spatial,
+ int64 input_second_spatial, int64 output_batch, int64 output_feature,
+ int64 output_first_spatial, int64 output_second_spatial,
int64 kernel_output_feature, int64 kernel_input_feature,
int64 kernel_first_spatial, int64 kernel_second_spatial);
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 5a899d550b..5bb81b80dd 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -102,7 +102,9 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
Padding padding, int64 lhs_dilation, int64 rhs_dilation,
const ConvolutionDimensionNumbers& dnums) {
- CHECK_EQ(dnums.spatial_dimensions_size(), 1);
+ CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
+ CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
+ CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
// Reuse the code for Array4D-convolution by extending the 3D input into a 4D
// array by adding a fourth dummy dimension of size 1 without stride, padding
// and dilation.
@@ -120,8 +122,9 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
});
// Add a second dummy spatial dimensions.
ConvolutionDimensionNumbers dnums2d = dnums;
- dnums2d.add_spatial_dimensions(3);
+ dnums2d.add_input_spatial_dimensions(3);
dnums2d.add_kernel_spatial_dimensions(3);
+ dnums2d.add_output_spatial_dimensions(3);
std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
{rhs_dilation, 1}, dnums2d);
@@ -465,9 +468,9 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
}
ordered_input_dimensions[0] =
- lhs_literal->shape().dimensions(dnums.spatial_dimensions(0));
+ lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
ordered_input_dimensions[1] =
- lhs_literal->shape().dimensions(dnums.spatial_dimensions(1));
+ lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
ordered_kernel_dimensions[0] =
rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index eb6a71242f..846ccdc83d 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -60,7 +60,9 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) {
TEST_F(ReferenceUtilTest, MatmulArray2D) {
Array2D<float> rhs({
- {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f},
+ {7.f, 8.f},
+ {9.f, 10.f},
+ {11.f, 12.f},
});
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto actual_literal = Literal::CreateR2FromArray2D(*result);
@@ -326,8 +328,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
dimension_numbers.set_input_feature_dimension(0);
dimension_numbers.set_output_batch_dimension(2);
dimension_numbers.set_output_feature_dimension(0);
- dimension_numbers.add_spatial_dimensions(1);
- dimension_numbers.add_spatial_dimensions(3);
+ dimension_numbers.add_input_spatial_dimensions(1);
+ dimension_numbers.add_output_spatial_dimensions(1);
+ dimension_numbers.add_input_spatial_dimensions(3);
+ dimension_numbers.add_output_spatial_dimensions(3);
dimension_numbers.set_kernel_output_feature_dimension(0);
dimension_numbers.set_kernel_input_feature_dimension(2);
dimension_numbers.add_kernel_spatial_dimensions(1);
@@ -380,8 +384,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
dimension_numbers.set_input_feature_dimension(0);
dimension_numbers.set_output_batch_dimension(2);
dimension_numbers.set_output_feature_dimension(0);
- dimension_numbers.add_spatial_dimensions(1);
- dimension_numbers.add_spatial_dimensions(3);
+ dimension_numbers.add_input_spatial_dimensions(1);
+ dimension_numbers.add_output_spatial_dimensions(1);
+ dimension_numbers.add_input_spatial_dimensions(3);
+ dimension_numbers.add_output_spatial_dimensions(3);
dimension_numbers.set_kernel_output_feature_dimension(0);
dimension_numbers.set_kernel_input_feature_dimension(2);
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 097f30be32..56dfb1cf0b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1624,8 +1624,11 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
ConvolutionDimensionNumbers dnums;
std::vector<int64> in_dims;
int in_channel_idx = -1;
- dnums.add_spatial_dimensions(-1); // filled in later
- dnums.add_spatial_dimensions(-1); // filled in later
+ // filled in later
+ dnums.add_input_spatial_dimensions(-1);
+ dnums.add_output_spatial_dimensions(-1);
+ dnums.add_input_spatial_dimensions(-1);
+ dnums.add_output_spatial_dimensions(-1);
for (int i = 0; i < strlen(options.dim_order); ++i) {
char ch = options.dim_order[i];
if (ch == 'N') {
@@ -1633,10 +1636,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
dnums.set_output_batch_dimension(i);
in_dims.push_back(options.in_batch);
} else if (ch == 'H') {
- dnums.set_spatial_dimensions(0, i);
+ dnums.set_input_spatial_dimensions(0, i);
+ dnums.set_output_spatial_dimensions(0, i);
in_dims.push_back(options.in_height);
} else if (ch == 'W') {
- dnums.set_spatial_dimensions(1, i);
+ dnums.set_input_spatial_dimensions(1, i);
+ dnums.set_output_spatial_dimensions(1, i);
in_dims.push_back(options.in_width);
} else if (ch == 'C') {
dnums.set_input_feature_dimension(i);
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index a3dd13811c..2136aeb387 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -41,8 +41,8 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension();
auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension();
- int num_spatial_dims = dnums.spatial_dimensions_size();
- int num_dims = num_spatial_dims + 2;
+ const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
+ const int64 num_dims = num_spatial_dims + 2;
// A canonical convolution's dimension numbers need to satisfy the
// following conditions (see cs/PotentiallyImplementedAsEigenConvolution).
@@ -59,10 +59,10 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
std::vector<int64> new_input_dims(num_dims);
new_input_dim_order[0] = input_batch_dim;
new_input_dims[0] = input->shape().dimensions(input_batch_dim);
- for (int i = 0; i < num_spatial_dims; ++i) {
- new_input_dim_order[i + 1] = dnums.spatial_dimensions(i);
+ for (int64 i = 0; i < num_spatial_dims; ++i) {
+ new_input_dim_order[i + 1] = dnums.input_spatial_dimensions(i);
new_input_dims[i + 1] =
- input->shape().dimensions(dnums.spatial_dimensions(i));
+ input->shape().dimensions(dnums.input_spatial_dimensions(i));
}
new_input_dim_order[num_dims - 1] = input_feature_dim;
new_input_dims[num_dims - 1] =
@@ -78,7 +78,7 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
std::vector<int64> new_kernel_dim_order(num_dims);
std::vector<int64> new_kernel_dims(num_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
+ for (int64 i = 0; i < num_spatial_dims; ++i) {
new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i);
new_kernel_dims[i] =
kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i));
@@ -102,10 +102,10 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
auto output_feature_dim = dnums.output_feature_dimension();
new_output_dim_order[0] = output_batch_dim;
new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim);
- for (int i = 0; i < num_spatial_dims; ++i) {
- new_output_dim_order[i + 1] = dnums.spatial_dimensions(i);
+ for (int64 i = 0; i < num_spatial_dims; ++i) {
+ new_output_dim_order[i + 1] = dnums.output_spatial_dimensions(i);
new_conv_dims[i + 1] =
- hlo->shape().dimensions(dnums.spatial_dimensions(i));
+ hlo->shape().dimensions(dnums.output_spatial_dimensions(i));
}
new_output_dim_order[num_dims - 1] = output_feature_dim;
new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim);
@@ -115,9 +115,10 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
ConvolutionDimensionNumbers new_dnums;
new_dnums.set_input_batch_dimension(0);
new_dnums.set_output_batch_dimension(0);
- for (int i = 0; i < num_spatial_dims; ++i) {
- new_dnums.add_spatial_dimensions(i + 1);
+ for (int64 i = 0; i < num_spatial_dims; ++i) {
+ new_dnums.add_input_spatial_dimensions(i + 1);
new_dnums.add_kernel_spatial_dimensions(i);
+ new_dnums.add_output_spatial_dimensions(i + 1);
}
new_dnums.set_input_feature_dimension(num_dims - 1);
new_dnums.set_output_feature_dimension(num_dims - 1);
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index d593ba26b6..968f53d5c7 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -69,8 +69,10 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(1);
dnums.set_output_batch_dimension(1);
- dnums.add_spatial_dimensions(2);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
dnums.set_input_feature_dimension(0);
dnums.set_output_feature_dimension(0);
dnums.add_kernel_spatial_dimensions(2);
@@ -125,8 +127,10 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
dnums.set_input_feature_dimension(3);
dnums.set_output_feature_dimension(3);
dnums.add_kernel_spatial_dimensions(0);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index d2e7f830d1..3993779da6 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -49,18 +49,21 @@ bool PotentiallyImplementedAsEigenConvolution(
convolution.convolution_dimension_numbers();
// Only 1D and 2D convolutions are supported at the moment.
// TODO(b/32897908): add an optimized implementation for 3D convolution.
- const int64 num_spatial_dims = dnums.spatial_dimensions_size();
+ const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
if (num_spatial_dims > 2) {
return false;
}
for (int64 i = 0; i < num_spatial_dims; ++i) {
- if (dnums.spatial_dimensions(i) != i + 1) {
+ if (dnums.input_spatial_dimensions(i) != i + 1) {
return false;
}
if (dnums.kernel_spatial_dimensions(i) != i) {
return false;
}
+ if (dnums.output_spatial_dimensions(i) != i + 1) {
+ return false;
+ }
}
const Shape& output_shape = convolution.shape();
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 49f4782693..502dd2e738 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -822,14 +822,16 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
- const auto save_operand_index = [&](
- const llvm_ir::IrArray::Index& operand_index) {
- for (int64 i = 0; i < rank; ++i) {
- llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
- selected_index_address, {ir_builder_.getInt32(i)});
- ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
- }
- };
+ const auto save_operand_index =
+ [&](const llvm_ir::IrArray::Index& operand_index) {
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* selected_index_address_slot =
+ ir_builder_.CreateInBoundsGEP(selected_index_address,
+ {ir_builder_.getInt32(i)});
+ ir_builder_.CreateStore(operand_index[i],
+ selected_index_address_slot);
+ }
+ };
llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
@@ -952,11 +954,12 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
// Input tensor.
const Shape& input_shape = convolution->operand(0)->shape();
int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension());
- int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0));
+ int64 input_rows =
+ input_shape.dimensions(dnums.input_spatial_dimensions(0));
int64 input_cols =
one_dim_convolution
? 1
- : input_shape.dimensions(dnums.spatial_dimensions(1));
+ : input_shape.dimensions(dnums.input_spatial_dimensions(1));
int64 input_channels =
input_shape.dimensions(dnums.input_feature_dimension());
@@ -976,11 +979,11 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
// Output tensor.
const Shape& convolution_shape = convolution->shape();
int64 output_rows =
- convolution_shape.dimensions(dnums.spatial_dimensions(0));
- int64 output_cols =
- one_dim_convolution
- ? 1
- : convolution_shape.dimensions(dnums.spatial_dimensions(1));
+ convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
+ int64 output_cols = one_dim_convolution
+ ? 1
+ : convolution_shape.dimensions(
+ dnums.output_spatial_dimensions(1));
// Extract the window stride for the convolution.
const Window& window = convolution->window();
@@ -1068,10 +1071,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
return EmitTargetElementLoop(
convolution, [this, convolution, lhs, rhs, window,
dnums](const llvm_ir::IrArray::Index& index) {
- int num_spatial_dims = dnums.spatial_dimensions_size();
+ int num_spatial_dims = dnums.output_spatial_dimensions_size();
std::vector<llvm::Value*> output_spatial(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
- output_spatial[i] = index[dnums.spatial_dimensions(i)];
+ output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
}
llvm::Value* output_feature = index[dnums.output_feature_dimension()];
llvm::Value* batch = index[dnums.output_batch_dimension()];
@@ -1091,8 +1094,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
for (int i = 0; i < num_spatial_dims; ++i) {
kernel_spatial[i] =
loops
- .AddLoop(0, rhs->shape().dimensions(
- dnums.kernel_spatial_dimensions(i)),
+ .AddLoop(0,
+ rhs->shape().dimensions(
+ dnums.kernel_spatial_dimensions(i)),
tensorflow::strings::StrCat("k", i))
->GetIndVarValue();
}
@@ -1108,17 +1112,18 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
// Calculate the spatial index in the input array, taking striding,
// dilation and padding into account. An index in the padding will be
// out of the bounds of the array.
- const auto calculate_input_index = [this](
- llvm::Value* output_index, llvm::Value* kernel_index,
- const WindowDimension& window_dim) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- output_index, ir_builder_.getInt64(window_dim.stride()));
- llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
- kernel_index, ir_builder_.getInt64(window_dim.window_dilation()));
- return ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
- ir_builder_.getInt64(window_dim.padding_low()));
- };
+ const auto calculate_input_index =
+ [this](llvm::Value* output_index, llvm::Value* kernel_index,
+ const WindowDimension& window_dim) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ output_index, ir_builder_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
+ kernel_index,
+ ir_builder_.getInt64(window_dim.window_dilation()));
+ return ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
+ ir_builder_.getInt64(window_dim.padding_low()));
+ };
std::vector<llvm::Value*> input_spatial(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
input_spatial[i] = calculate_input_index(
@@ -1144,7 +1149,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
for (int i = 0; i < num_spatial_dims; ++i) {
llvm::ConstantInt* input_bound =
ir_builder_.getInt64(window_util::DilatedBound(
- lhs->shape().dimensions(dnums.spatial_dimensions(i)),
+ lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
window.dimensions(i).base_dilation()));
llvm::Value* dim_in_bound =
ir_builder_.CreateICmpULT(input_spatial[i], input_bound);
@@ -1176,7 +1181,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
int num_dims = num_spatial_dims + 2;
llvm_ir::IrArray::Index input_index(num_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
- input_index[dnums.spatial_dimensions(i)] = input_spatial[i];
+ input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
}
input_index[dnums.input_feature_dimension()] = input_feature;
input_index[dnums.input_batch_dimension()] = batch;
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
index 5aaf072f9d..828ae675d7 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
@@ -74,9 +74,10 @@ MatchBackwardFilter(HloInstruction* conv) {
conv->convolution_dimension_numbers();
auto input_batch_dim = conv_dnums.input_batch_dimension();
auto input_feature_dim = conv_dnums.input_feature_dimension();
+ auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
auto output_batch_dim = conv_dnums.output_batch_dimension();
auto output_feature_dim = conv_dnums.output_feature_dimension();
- auto spatial_dims = conv_dnums.spatial_dimensions();
+ auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
for (const WindowDimension& window_dim : conv->window().dimensions()) {
if (window_dim.stride() != 1) {
@@ -108,11 +109,11 @@ MatchBackwardFilter(HloInstruction* conv) {
//
// Compute the window of the backward convolution.
Window backward_conv_window;
- for (int i = 0; i < spatial_dims.size(); ++i) {
+ for (int i = 0; i < input_spatial_dims.size(); ++i) {
WindowDimension* dim = backward_conv_window.add_dimensions();
// The window size of the backward convolution equals the output size of the
// forward convolution.
- int64 filter_size = conv->shape().dimensions(spatial_dims[i]);
+ int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]);
dim->set_size(filter_size);
// The window stride equals the window dilation of the forward convolution.
dim->set_stride(conv->window().dimensions(i).window_dilation());
@@ -120,7 +121,8 @@ MatchBackwardFilter(HloInstruction* conv) {
// activations.
dim->set_padding_low(conv->window().dimensions(i).padding_low());
- int64 input_size = conv->operand(0)->shape().dimensions(spatial_dims[i]);
+ int64 input_size =
+ conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
int64 output_size = conv->window().dimensions(i).size();
// Compute the range of the amount of valid high padding. We first compute
// min_padding_high, the amount of padding on the right/bottom to ensure the
@@ -189,8 +191,11 @@ MatchBackwardFilter(HloInstruction* conv) {
backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
backward_conv_dnums.set_output_batch_dimension(output_feature_dim);
backward_conv_dnums.set_output_feature_dimension(output_batch_dim);
- for (int i = 0; i < spatial_dims.size(); ++i) {
- backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]);
+ for (int i = 0; i < input_spatial_dims.size(); ++i) {
+ backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
+ }
+ for (int i = 0; i < output_spatial_dims.size(); ++i) {
+ backward_conv_dnums.add_output_spatial_dimensions(output_spatial_dims[i]);
}
// The dimension numbering of the output of the forward convolution (before
// transposition) is the same as that of the activations (according to the
@@ -205,9 +210,9 @@ MatchBackwardFilter(HloInstruction* conv) {
PositionInContainer(transpose->dimensions(), output_batch_dim));
backward_conv_dnums.set_kernel_output_feature_dimension(
PositionInContainer(transpose->dimensions(), output_feature_dim));
- for (int i = 0; i < spatial_dims.size(); ++i) {
+ for (int i = 0; i < output_spatial_dims.size(); ++i) {
backward_conv_dnums.add_kernel_spatial_dimensions(
- PositionInContainer(transpose->dimensions(), spatial_dims[i]));
+ PositionInContainer(transpose->dimensions(), output_spatial_dims[i]));
}
return std::make_tuple(true, std::vector<HloInstruction*>({transpose, conv}),
@@ -272,12 +277,14 @@ MatchBackwardInput(HloInstruction* conv) {
}
}
- const auto& spatial_dims = dnums.spatial_dimensions();
- CHECK_EQ(conv->window().dimensions().size(), spatial_dims.size());
+ const auto& input_spatial_dims = dnums.input_spatial_dimensions();
+ const auto& output_spatial_dims = dnums.output_spatial_dimensions();
+ CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
+ CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());
const Window& old_window = conv->window();
Window new_window = old_window;
- for (size_t i = 0; i < spatial_dims.size(); ++i) {
+ for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
// Restore backward convolution's padding config from the matched pattern.
// See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc
// for how we convert backward input convolution to a variant of forward
@@ -310,8 +317,9 @@ MatchBackwardInput(HloInstruction* conv) {
// end at the border. The maximum amount (max_padding_high) equals
// min_padding_high+stride-1 -- max_padding_high+1 would cause the output
// size to change.
- auto unpadded_input_size = conv->shape().dimensions(spatial_dims[i]);
- auto output_size = conv->operand(0)->shape().dimensions(spatial_dims[i]);
+ auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
+ auto output_size =
+ conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
auto total_pad_size = padded_input_size - unpadded_input_size;
auto min_padding_high = total_pad_size - backward_padding_low;
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
index 19b122ba06..112c496e1f 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
@@ -49,8 +49,10 @@ class ConvolutionFoldingTest : public HloTestBase {
tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3);
tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0);
- tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1);
- tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2);
+ tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
+ tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
+ tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
+ tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(2);
tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
3);
@@ -61,8 +63,10 @@ class ConvolutionFoldingTest : public HloTestBase {
tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
- tf_default_dnums_for_backward_input_.add_spatial_dimensions(1);
- tf_default_dnums_for_backward_input_.add_spatial_dimensions(2);
+ tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
+ tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
+ tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
+ tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
@@ -258,8 +262,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) {
conv_dnums.set_output_batch_dimension(0);
conv_dnums.set_input_feature_dimension(1);
conv_dnums.set_output_feature_dimension(1);
- conv_dnums.add_spatial_dimensions(2);
- conv_dnums.add_spatial_dimensions(3);
+ conv_dnums.add_input_spatial_dimensions(2);
+ conv_dnums.add_output_spatial_dimensions(2);
+ conv_dnums.add_input_spatial_dimensions(3);
+ conv_dnums.add_output_spatial_dimensions(3);
conv_dnums.set_kernel_input_feature_dimension(0);
conv_dnums.set_kernel_output_feature_dimension(1);
conv_dnums.add_kernel_spatial_dimensions(2);
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 5fe5f55857..037eec8ef5 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,12 +29,12 @@ namespace se = ::perftools::gputools;
namespace xla {
namespace gpu {
+using se::dnn::AlgorithmDesc;
using se::dnn::BatchDescriptor;
using se::dnn::ConvolutionDescriptor;
using se::dnn::DataLayout;
using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
-using se::dnn::AlgorithmDesc;
ConvolveScratchAllocator::ConvolveScratchAllocator(
int device_ordinal, DeviceMemoryAllocator* memory_allocator)
@@ -131,8 +131,9 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream(
const int effective_num_dimensions = std::max(2, num_dimensions);
CHECK_EQ(F32, output_shape_.element_type());
- CHECK_EQ(num_dimensions, dim_nums_.spatial_dimensions_size());
+ CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size());
CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size());
+ CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size());
for (const WindowDimension& dim : window_.dimensions()) {
CHECK_EQ(dim.padding_low(), dim.padding_high());
}
@@ -148,7 +149,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream(
// Note that the dimensions are reversed. The same holds below.
input_descriptor.set_spatial_dim(
static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
- input_shape_.dimensions(dim_nums_.spatial_dimensions(dim)));
+ input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim)));
}
FilterDescriptor filter_descriptor(effective_num_dimensions);
@@ -182,7 +183,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream(
for (int dim = 0; dim < num_dimensions; ++dim) {
output_descriptor.set_spatial_dim(
static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
- output_shape_.dimensions(dim_nums_.spatial_dimensions(dim)));
+ output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim)));
}
// Add a singleton dimension in the 1D convolution case.
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 9a4bfd0905..1d47ffde43 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -156,8 +156,10 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) {
conv_dnums.set_output_batch_dimension(0);
conv_dnums.set_input_feature_dimension(1);
conv_dnums.set_output_feature_dimension(1);
- conv_dnums.add_spatial_dimensions(2);
- conv_dnums.add_spatial_dimensions(3);
+ conv_dnums.add_input_spatial_dimensions(2);
+ conv_dnums.add_output_spatial_dimensions(2);
+ conv_dnums.add_input_spatial_dimensions(3);
+ conv_dnums.add_output_spatial_dimensions(3);
conv_dnums.set_kernel_output_feature_dimension(0);
conv_dnums.set_kernel_input_feature_dimension(1);
conv_dnums.add_kernel_spatial_dimensions(2);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 8fb7a6adda..658fd05cd4 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -100,7 +100,7 @@ bool ImplementedAsDnnConvolution(const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kConvolution) {
const ConvolutionDimensionNumbers& dnums =
hlo.convolution_dimension_numbers();
- if (dnums.spatial_dimensions_size() > 3) {
+ if (dnums.input_spatial_dimensions_size() > 3) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc
index 0bbd63fb7b..d475c4171b 100644
--- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc
@@ -80,9 +80,9 @@ Status GpuLayoutAssignment::AddBackendConstraints(
const ConvolutionDimensionNumbers& dimension_numbers =
instruction->convolution_dimension_numbers();
std::vector<int64> input_layout;
- for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0;
- --i) {
- input_layout.push_back(dimension_numbers.spatial_dimensions(i));
+ for (int i = dimension_numbers.input_spatial_dimensions_size() - 1;
+ i >= 0; --i) {
+ input_layout.push_back(dimension_numbers.input_spatial_dimensions(i));
}
input_layout.push_back(dimension_numbers.input_feature_dimension());
input_layout.push_back(dimension_numbers.input_batch_dimension());
@@ -102,9 +102,9 @@ Status GpuLayoutAssignment::AddBackendConstraints(
*filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout);
std::vector<int64> output_layout;
- for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0;
- --i) {
- output_layout.push_back(dimension_numbers.spatial_dimensions(i));
+ for (int i = dimension_numbers.output_spatial_dimensions_size() - 1;
+ i >= 0; --i) {
+ output_layout.push_back(dimension_numbers.output_spatial_dimensions(i));
}
output_layout.push_back(dimension_numbers.output_feature_dimension());
output_layout.push_back(dimension_numbers.output_batch_dimension());
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 9274e16a45..11290eda4f 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -49,8 +49,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
// applies positive padding and dilation.
PaddingConfig padding_config =
MakeNoPaddingConfig(input->shape().dimensions_size());
- for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) {
- int64 dim = conv_dnums.spatial_dimensions(i);
+ for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
+ int64 dim = conv_dnums.input_spatial_dimensions(i);
padding_config.mutable_dimensions(dim)->set_edge_padding_low(
std::max<int64>(0LL, conv_window.dimensions(i).padding_low()));
padding_config.mutable_dimensions(dim)->set_edge_padding_high(
@@ -81,8 +81,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
std::vector<int64> limit_indices(input->shape().dimensions().begin(),
input->shape().dimensions().end());
std::vector<int64> strides(input->shape().dimensions_size(), 1);
- for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) {
- int64 dim = conv_dnums.spatial_dimensions(i);
+ for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
+ int64 dim = conv_dnums.input_spatial_dimensions(i);
// If dimension "dim" has negative padding, increase the start index or
// decrement the limit index by the amount of negative padding.
start_indices[dim] +=
@@ -117,8 +117,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
padding_config.add_dimensions();
}
- for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) {
- int64 dim = conv_dnums.spatial_dimensions(i);
+ for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
+ int64 dim = conv_dnums.kernel_spatial_dimensions(i);
padding_config.mutable_dimensions(dim)->set_interior_padding(
conv_window.dimensions(i).window_dilation() - 1);
}
@@ -229,7 +229,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// later. Therefore, the amount of new padding (low or high) is the minimum
// of the amount of old padding low and old padding high.
int64 new_conv_padding = std::min(padding_low, padding_high);
- int64 dim = backward_conv_dnums.spatial_dimensions(i);
+ int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
padding_low - new_conv_padding);
input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
@@ -369,12 +369,11 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
std::vector<int64> limit_indices(
new_backward_conv->shape().dimensions().begin(),
new_backward_conv->shape().dimensions().end());
- std::vector<int64> strides(new_backward_conv->shape().dimensions_size(),
- 1LL);
+ std::vector<int64> strides(new_backward_conv->shape().dimensions_size(), 1LL);
for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
int64 padding_low = backward_conv->window().dimensions(i).padding_low();
int64 padding_high = backward_conv->window().dimensions(i).padding_high();
- int64 dim = backward_conv_dnums.spatial_dimensions(i);
+ int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
if (padding_low > padding_high) {
// If the amount of low padding (of the old backward convolution) is
// larger, we internally pad the low end of the activations and slice
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 0a1ebe3416..e693d167a1 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -812,7 +812,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));
const auto& dnums = conv->convolution_dimension_numbers();
- const int64 num_spatial_dims = dnums.spatial_dimensions_size();
+ const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
+ CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size());
CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
CHECK_GE(num_spatial_dims, 0);
CHECK_EQ(window.dimensions_size(), num_spatial_dims);
@@ -877,13 +878,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
// Spatial dimension number for input (lhs) and output.
- const int64 spatial_dim = dnums.spatial_dimensions(ki);
+ const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
+ const int64 output_spatial_dim =
+ dnums.output_spatial_dimensions(ki);
// Calculate lhs (input) index without taking base dilation into
// account.
const auto& window_dim = window.dimensions(ki);
const int64 undilated_index =
- out_index[spatial_dim] * window_dim.stride() -
+ out_index[output_spatial_dim] * window_dim.stride() -
window_dim.padding_low() +
rhs_spatial_index[ki] * window_dim.window_dilation();
// Skip if the lhs (input) index is to be dilated.
@@ -892,12 +895,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
}
// Calculate the actual lhs (input) index after dilation.
- lhs_index[spatial_dim] =
+ lhs_index[input_spatial_dim] =
undilated_index / window_dim.base_dilation();
// Skip if input index is not in bound.
- if (!(lhs_index[spatial_dim] >= 0 &&
- lhs_index[spatial_dim] < lhs_shape.dimensions(spatial_dim))) {
+ if (!(lhs_index[input_spatial_dim] >= 0 &&
+ lhs_index[input_spatial_dim] <
+ lhs_shape.dimensions(input_spatial_dim))) {
goto cnt;
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index d0d6029d5f..b2c4351896 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -751,7 +751,8 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) {
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
- dnums.add_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
dnums.set_kernel_output_feature_dimension(0);
dnums.set_kernel_input_feature_dimension(1);
@@ -886,8 +887,10 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
dnums.set_output_batch_dimension(2);
dnums.set_input_feature_dimension(0);
dnums.set_output_feature_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
dnums.set_kernel_output_feature_dimension(0);
dnums.set_kernel_input_feature_dimension(2);
@@ -960,8 +963,10 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) {
dnums.set_output_batch_dimension(2);
dnums.set_input_feature_dimension(0);
dnums.set_output_feature_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
dnums.set_kernel_output_feature_dimension(0);
dnums.set_kernel_input_feature_dimension(2);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 854185af56..c30c432654 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -3021,25 +3021,25 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
// lhs_dims[i] is the symbol of the logical dimension i for the lhs
// operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
- std::vector<string> lhs_dims(2 + dnums.spatial_dimensions().size());
+ std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
lhs_dims[dnums.input_batch_dimension()] = 'b';
lhs_dims[dnums.input_feature_dimension()] = 'f';
- for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) {
- lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i);
+ for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
+ lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
}
std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
- for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) {
+ for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
}
- std::vector<string> output_dims(2 + dnums.spatial_dimensions().size());
+ std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
output_dims[dnums.output_batch_dimension()] = 'b';
output_dims[dnums.output_feature_dimension()] = 'f';
- for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) {
- output_dims[dnums.spatial_dimensions(i)] = StrCat(i);
+ for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
+ output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
}
result += "dim_labels=";
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 0a2bf939c1..3df1911d07 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1445,7 +1445,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str());
}
- if (dnums.spatial_dimensions_size() !=
+ if (dnums.input_spatial_dimensions_size() !=
dnums.kernel_spatial_dimensions_size()) {
return InvalidArgument(
"Both arguments to convolution must have same number of dimensions.\n"
@@ -1453,7 +1453,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
window.DebugString().c_str());
}
- const int num_spatial_dims = dnums.spatial_dimensions_size();
+ const int num_spatial_dims = dnums.input_spatial_dimensions_size();
if (window.dimensions_size() != num_spatial_dims) {
return InvalidArgument(
"Window must have same number of dimensions as dimension numbers.\n"
@@ -1482,8 +1482,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
std::vector<int64> input_dnums(num_dims);
input_dnums[0] = dnums.input_batch_dimension();
input_dnums[1] = dnums.input_feature_dimension();
- std::copy(dnums.spatial_dimensions().begin(),
- dnums.spatial_dimensions().end(), input_dnums.begin() + 2);
+ std::copy(dnums.input_spatial_dimensions().begin(),
+ dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
std::sort(input_dnums.begin(), input_dnums.end());
std::vector<int64> window_dnums(num_dims);
@@ -1493,12 +1493,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
std::sort(window_dnums.begin(), window_dnums.end());
+ std::vector<int64> output_dnums(num_dims);
+ output_dnums[0] = dnums.output_batch_dimension();
+ output_dnums[1] = dnums.output_feature_dimension();
+ std::copy(dnums.output_spatial_dimensions().begin(),
+ dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
+ std::sort(output_dnums.begin(), output_dnums.end());
+
std::vector<int64> expected_dnums(num_dims);
std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
- !std::all_of(window_dnums.begin(), window_dnums.end(), in_range)) {
+ !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) ||
+ !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
return InvalidArgument(
"A dimension number is out of range in convolution: %s",
dnums.DebugString().c_str());
@@ -1516,10 +1524,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"once: %s",
dnums.DebugString().c_str());
}
+ if (output_dnums != expected_dnums) {
+ return InvalidArgument(
+ "Output dimensions of convolution must contain each dimension exactly "
+ "once: %s",
+ dnums.DebugString().c_str());
+ }
std::vector<int64> input_spatial_dims(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
- input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i));
+ input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
}
const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());
@@ -1567,7 +1581,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
dimensions[dnums.output_batch_dimension()] = input_batch;
dimensions[dnums.output_feature_dimension()] = kernel_output_features;
for (int i = 0; i < num_spatial_dims; ++i) {
- dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i);
+ dimensions[dnums.output_spatial_dimensions(i)] =
+ window_output_shape.dimensions(i);
}
return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index d12f7bd145..be93c879c0 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -395,8 +395,10 @@ TEST_F(ShapeInferenceTest, Convolve) {
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
- dnums.add_spatial_dimensions(2);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
// Dimension order: x1, batch, feature, x0
Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
@@ -437,8 +439,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
- dnums.add_spatial_dimensions(2);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
// Dimension order: x1, batch, feature, x0
Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
@@ -480,8 +484,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
- dnums.add_spatial_dimensions(2);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
// Dimension order: x1, batch, feature, x0
Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
@@ -524,8 +530,10 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dnums.set_output_batch_dimension(3);
dnums.set_input_feature_dimension(2);
dnums.set_output_feature_dimension(2);
- dnums.add_spatial_dimensions(0);
- dnums.add_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(0);
+ dnums.add_output_spatial_dimensions(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0
dnums.set_kernel_output_feature_dimension(3);
dnums.add_kernel_spatial_dimensions(0);
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 8c2640adf5..fb55d4e543 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -58,27 +58,11 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
return {};
}
- const ConvolutionDimensionNumbers& dnums =
- convolution.convolution_dimension_numbers();
-
TransposeFolding::OperandIndices operand_set;
for (int64 i = 0; i < convolution.operand_count(); ++i) {
auto& operand = *convolution.operand(i);
if (operand.opcode() == HloOpcode::kTranspose &&
operand.user_count() == 1) {
- const auto& transpose_dimensions = operand.dimensions();
- // We can transpose the LHS so long as it doesn't move around spatial
- // dimensions because ConvolutionDimensionNumbers doesn't have different
- // fields for input and output spatial dimensions.
- if (i == 0 &&
- std::any_of(dnums.spatial_dimensions().begin(),
- dnums.spatial_dimensions().end(),
- [&](const int64 spatial_dimension) {
- return transpose_dimensions[spatial_dimension] !=
- spatial_dimension;
- })) {
- continue;
- }
operand_set.push_back(i);
}
}
@@ -137,7 +121,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
transpose_dimensions[dnums.input_batch_dimension()]);
new_dnums.set_input_feature_dimension(
transpose_dimensions[dnums.input_feature_dimension()]);
- for (const auto& spatial_dimension : dnums.spatial_dimensions()) {
+ for (const auto& spatial_dimension : dnums.input_spatial_dimensions()) {
CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]);
}
new_lhs = &transpose_operand;
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 00462f9be1..6ac32e88f1 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -362,10 +362,18 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
EXPECT_EQ(
dnums.input_batch_dimension(),
new_conv->convolution_dimension_numbers().input_feature_dimension());
- EXPECT_EQ(dnums.spatial_dimensions(0),
- new_conv->convolution_dimension_numbers().spatial_dimensions(0));
- EXPECT_EQ(dnums.spatial_dimensions(1),
- new_conv->convolution_dimension_numbers().spatial_dimensions(1));
+ EXPECT_EQ(
+ dnums.input_spatial_dimensions(0),
+ new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
+ EXPECT_EQ(
+ dnums.input_spatial_dimensions(1),
+ new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
+ EXPECT_EQ(
+ dnums.output_spatial_dimensions(0),
+ new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
+ EXPECT_EQ(
+ dnums.output_spatial_dimensions(1),
+ new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index b0a63bccbb..896b34fb6e 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -39,8 +39,8 @@ class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {};
// Tests the convolution operation with invalid input dimension numbers.
TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) {
auto dimension_numbers_status =
- ComputationBuilder::CreateConvDimensionNumbers(0, 2, 0, 2, 2, 3, 0, 1, 2,
- 3);
+ ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0,
+ 1, 2, 3);
ASSERT_FALSE(dimension_numbers_status.ok());
ASSERT_THAT(dimension_numbers_status.status().error_message(),
::testing::HasSubstr("input are not unique"));
@@ -49,13 +49,23 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) {
// Tests the convolution operation with invalid weight dimension numbers.
TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) {
auto dimension_numbers_status =
- ComputationBuilder::CreateConvDimensionNumbers(0, 1, 0, 1, 2, 3, 2, 3, 2,
- 3);
+ ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0,
+ 2, 2, 3);
ASSERT_FALSE(dimension_numbers_status.ok());
ASSERT_THAT(dimension_numbers_status.status().error_message(),
::testing::HasSubstr("weight are not unique"));
}
+// Tests the convolution operation with invalid output dimension numbers.
+TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) {
+ auto dimension_numbers_status =
+ ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0,
+ 1, 2, 3);
+ ASSERT_FALSE(dimension_numbers_status.ok());
+ ASSERT_THAT(dimension_numbers_status.status().error_message(),
+ ::testing::HasSubstr("output are not unique"));
+}
+
XLA_TEST_F(ConvolutionDimensionNumbersTest,
TwoConvsWithDifferentDimensionNumbers) {
auto input_array = MakeUnique<Array4D<float>>(2, 3, 5, 5);
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 8de7c9ffdc..2924c08615 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -370,9 +370,12 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(2);
- dnums.add_spatial_dimensions(3);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ dnums.add_output_spatial_dimensions(3);
dnums.set_input_feature_dimension(4);
dnums.set_output_feature_dimension(4);
dnums.add_kernel_spatial_dimensions(0);
@@ -423,8 +426,10 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) {
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
dnums.set_input_feature_dimension(3);
dnums.set_output_feature_dimension(3);
dnums.add_kernel_spatial_dimensions(0);
@@ -538,7 +543,8 @@ XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) {
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
- dnums.add_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
dnums.set_input_feature_dimension(2);
dnums.set_output_feature_dimension(2);
dnums.add_kernel_spatial_dimensions(0);
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 9b36e3722b..9c1145def8 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -320,9 +320,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
- const Array4D<float> filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0
- 0, 100, 0, // row 1
- 10, 0, 1}); // row 2
+ const Array4D<float> filter_array(1, 1, 3, 3,
+ {10000, 0, 1000, // row 0
+ 0, 100, 0, // row 1
+ 10, 0, 1}); // row 2
auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
builder.Conv(input, filter, {1, 1}, Padding::kSame);
@@ -472,7 +473,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
builder.Conv(input, filter, {1, 1}, Padding::kValid);
std::vector<float> expected_data = {
- 23, 33, 43,
+ 23,
+ 33,
+ 43,
};
Array4D<float> expected(bs, 1, 1, 1, expected_data);
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -669,10 +672,11 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) {
std::iota(input_data.begin(), input_data.end(), 1.0);
Array4D<float> input_array(1, 1, 3, 4, input_data);
- Array4D<float> filter_array(1, 1, 4, 3, {100, 10, 1, //
- 200, 20, 2, //
- 300, 30, 3, //
- 400, 40, 4});
+ Array4D<float> filter_array(1, 1, 4, 3,
+ {100, 10, 1, //
+ 200, 20, 2, //
+ 300, 30, 3, //
+ 400, 40, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
builder.ConvGeneralDilated(
@@ -681,9 +685,10 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) {
/*rhs_dilation=*/{},
ComputationBuilder::CreateDefaultConvDimensionNumbers());
- Array4D<float> expected(1, 1, 3, 5, {204, 40, 406, 60, 608, //
- 1518, 180, 1821, 210, 2124, //
- 4146, 460, 4651, 510, 5156});
+ Array4D<float> expected(1, 1, 3, 5,
+ {204, 40, 406, 60, 608, //
+ 1518, 180, 1821, 210, 2124, //
+ 4146, 460, 4651, 510, 5156});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
@@ -926,7 +931,8 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) {
ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
}
-XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) {
+XLA_TEST_F(ConvolutionVariantsTest,
+ RandomData_Input16x16x16x16_Filter16x16x16x16) {
constexpr int bs = 16;
constexpr int iz = 16;
constexpr int oz = 16;
@@ -976,8 +982,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
// NHWC input format.
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
dnums.set_input_feature_dimension(3);
dnums.set_output_feature_dimension(3);
@@ -1018,8 +1026,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
// NHWC input format.
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
dnums.set_input_feature_dimension(3);
dnums.set_output_feature_dimension(3);
@@ -1060,8 +1070,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
// NHWC input format.
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
dnums.set_input_feature_dimension(3);
dnums.set_output_feature_dimension(3);
@@ -1099,8 +1111,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
// NHWC input format.
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
- dnums.add_spatial_dimensions(1);
- dnums.add_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
dnums.set_input_feature_dimension(3);
dnums.set_output_feature_dimension(3);
@@ -1131,7 +1145,8 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
// Conv([1,2,3], Reverse([5,6]), padding_low=1)
// into
// BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1)
-XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) {
+XLA_TEST_F(ConvolutionVariantsTest,
+ BackwardInputLowPaddingLessThanHighPadding) {
ComputationBuilder builder(client_, TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
@@ -1149,7 +1164,8 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding)
// Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3)
// into
// BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1))
-XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) {
+XLA_TEST_F(ConvolutionVariantsTest,
+ BackwardInputLowPaddingGreaterThanHighPadding) {
ComputationBuilder builder(client_, TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
@@ -1206,7 +1222,8 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
ComputeAndCompareR4<float>(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_);
}
-XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) {
+XLA_TEST_F(ConvolutionVariantsTest,
+ BackwardFilterLowPaddingLessThanHighPadding) {
ComputationBuilder builder(client_, TestName());
// activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0
@@ -1230,7 +1247,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding)
}
XLA_TEST_F(ConvolutionVariantsTest,
- BackwardFilterLowPaddingGreaterThanHighPadding) {
+ BackwardFilterLowPaddingGreaterThanHighPadding) {
ComputationBuilder builder(client_, TestName());
// activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index a10497665a..47979ec6f3 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -1685,7 +1685,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
StrCat("expects unique lhs dimension numbers, but sees ", lhs));
}
for (int i = 0; i < rank - 2; i++) {
- dnums->add_spatial_dimensions(-1);
+ dnums->add_input_spatial_dimensions(-1);
}
for (int i = 0; i < rank; i++) {
char c = lhs[i];
@@ -1694,7 +1694,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
} else if (c == 'f') {
dnums->set_input_feature_dimension(i);
} else if (c < '0' + rank && c >= '0') {
- dnums->set_spatial_dimensions(c - '0', i);
+ dnums->set_input_spatial_dimensions(c - '0', i);
} else {
return TokenError(
Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
@@ -1732,6 +1732,9 @@ bool HloParser::ParseConvolutionDimensionNumbers(
return TokenError(
StrCat("expects unique output dimension numbers, but sees ", out));
}
+ for (int i = 0; i < rank - 2; i++) {
+ dnums->add_output_spatial_dimensions(-1);
+ }
for (int i = 0; i < rank; i++) {
char c = out[i];
if (c == 'b') {
@@ -1739,11 +1742,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
} else if (c == 'f') {
dnums->set_output_feature_dimension(i);
} else if (c < '0' + rank && c >= '0') {
- if (dnums->spatial_dimensions(c - '0') != i) {
- return TokenError(
- "output spatial dimensions should be the same as input spatial "
- "dimensions");
- }
+ dnums->set_output_spatial_dimensions(c - '0', i);
} else {
return TokenError(
Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index e56f120def..90cdb87a1e 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -873,12 +873,6 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
.status()
.error_message(),
"must have the same rank");
-
- ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=0bf_io0->b0f", suffix))
- .status()
- .error_message(),
- "output spatial dimensions should be the same as input "
- "spatial dimensions");
}
TEST_F(HloParserTest, UnexpectedAttribute) {
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index d3c5a88807..b560354050 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -417,15 +417,9 @@ message ConvolutionDimensionNumbers {
// The number of the dimension that represents features in the input.
int64 input_feature_dimension = 8;
- // The number of the dimension that represents batch in the output.
- int64 output_batch_dimension = 9;
-
- // The number of the dimension that represents features in the output.
- int64 output_feature_dimension = 10;
-
// The dimension numbers for the spatial dimensions that the window
- // moves through in the input (lhs) and output.
- repeated int64 spatial_dimensions = 5;
+ // moves through in the input.
+ repeated int64 input_spatial_dimensions = 11;
// The number of the dimension that represents input features in the
// convolutional kernel (rhs).
@@ -439,6 +433,18 @@ message ConvolutionDimensionNumbers {
// moves through in the kernel (rhs). window.strides(0) is the
// stride in the kernel_spatial_dimensions(0) dimension.
repeated int64 kernel_spatial_dimensions = 6;
+
+ // The number of the dimension that represents batch in the output.
+ int64 output_batch_dimension = 9;
+
+ // The number of the dimension that represents features in the output.
+ int64 output_feature_dimension = 10;
+
+ // The dimension numbers for the spatial dimensions that the window
+ // moves through in the output.
+ repeated int64 output_spatial_dimensions = 12;
+
+ // Next = 13
};
message ConvolveRequest {