diff options
author | David Majnemer <majnemer@google.com> | 2017-11-28 20:41:47 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-28 20:45:25 -0800 |
commit | bdde4d040cf01ef241ad349cf222c227b9a88814 (patch) | |
tree | 0c7e7b9fe1ac0ded4e41c8855e6bae450bff93f6 /tensorflow/compiler/xla/service/transpose_folding_test.cc | |
parent | d2e7a2e4bf295a23d6a2e86aa7e0636f00cc2d75 (diff) |
[XLA] Support transposing the spatial dimensions of a convolution's activations
PiperOrigin-RevId: 177260886
Diffstat (limited to 'tensorflow/compiler/xla/service/transpose_folding_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/transpose_folding_test.cc | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 6ac32e88f1..ba99852905 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -376,5 +376,69 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } +// Test that a transpose of every dimension in the activations gets folded into +// convolution. +TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set<HloInstruction*> instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); +} + } // namespace } // namespace xla |