aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/transpose_folding_test.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2017-11-28 20:41:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-28 20:45:25 -0800
commitbdde4d040cf01ef241ad349cf222c227b9a88814 (patch)
tree0c7e7b9fe1ac0ded4e41c8855e6bae450bff93f6 /tensorflow/compiler/xla/service/transpose_folding_test.cc
parentd2e7a2e4bf295a23d6a2e86aa7e0636f00cc2d75 (diff)
[XLA] Support transposing the spatial dimensions of a convolution's activations
PiperOrigin-RevId: 177260886
Diffstat (limited to 'tensorflow/compiler/xla/service/transpose_folding_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc64
1 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 6ac32e88f1..ba99852905 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -376,5 +376,69 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
}
+// Test that a transpose of every dimension in the activations gets folded into
+// convolution.
+TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
+ auto builder = HloComputation::Builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
+ /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
+ /*name=*/"y"));
+ HloInstruction* transpose_x =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2}));
+ auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ Window window;
+ for (int i = 0; i < 2; ++i) {
+ WindowDimension* dim = window.add_dimensions();
+ dim->set_padding_low(0);
+ dim->set_padding_high(0);
+ dim->set_base_dilation(1);
+ dim->set_window_dilation(1);
+ dim->set_stride(1);
+ dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
+ }
+ StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
+ transpose_x->shape(), y->shape(), window, dnums);
+ EXPECT_IS_OK(conv_shape);
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+
+ HloModule module("test_module");
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build(conv));
+ FoldTranspose(&module);
+
+ // Instructions after folding: x, y, and the convolution.
+ std::unordered_set<HloInstruction*> instruction_set(
+ entry_computation->instructions().begin(),
+ entry_computation->instructions().end());
+ EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+ EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+ EXPECT_EQ(1, instruction_set.size())
+ << "entry_computation should contain exactly 3 instructions.";
+ HloInstruction* new_conv = *instruction_set.begin();
+ EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
+ EXPECT_EQ(dnums.input_feature_dimension(),
+ new_conv->convolution_dimension_numbers().input_batch_dimension());
+ EXPECT_EQ(
+ dnums.input_batch_dimension(),
+ new_conv->convolution_dimension_numbers().input_feature_dimension());
+ EXPECT_EQ(
+ dnums.input_spatial_dimensions(0),
+ new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
+ EXPECT_EQ(
+ dnums.input_spatial_dimensions(1),
+ new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
+ EXPECT_EQ(
+ dnums.output_spatial_dimensions(0),
+ new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
+ EXPECT_EQ(
+ dnums.output_spatial_dimensions(1),
+ new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
+}
+
} // namespace
} // namespace xla