author    Sanjoy Das <sanjoy@google.com>                   2018-05-04 22:04:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-07 15:40:20 -0700
commit    150089e6e67e4492f098cdd8f9f2f48dc9f9cc56
tree      778d8f20ab300ceea85a36c22d150570ff9530f8
parent    939fc534a4b2f227ee337e7dcfa82ec9b6337814
Remove uses of the kTransposeDot fusion
I didn't remove the enum itself, but after this change removing the enum
should be a simple NFC change (famous last words!). This will make it
easier to implement BatchDot on CPU.

The change removes usages of kTransposeDot by:

 - Teaching TransposeFolding to "fuse" transposes into dots by flipping the
   lhs_contracting_dims/rhs_contracting_dims fields.

 - Replacing the notion of transpose_lhs/transpose_rhs in the IR emitters
   with "has a non-canonical LHS contraction dimension"/"has a non-canonical
   RHS contraction dimension", where the canonical LHS and RHS contraction
   dims [0] are 1 and 0 respectively.

Some tests were getting away with creating Dot instructions with their
dimension numbers unset. I've fixed these to create canonical dot operations
instead.

It is possible (but hard to tell without trying) that some of the IR
emission logic and Eigen runtime calls can now be simplified further. For
instance, instead of passing a `transpose_lhs` and `transpose_rhs` to the
Eigen GEMM routines, we could pass in the LHS and RHS contraction dimensions
directly.

[0] See HloInstruction::CreateCanonicalDot.

PiperOrigin-RevId: 195514907
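To make the dimension-number flipping concrete, below is a minimal
standalone C++ sketch of the rule the first bullet describes. The names here
(DotDims, FoldRhsTranspose) are illustrative inventions, not the real XLA
API; the actual pass rewrites the DotDimensionNumbers proto attached to an
HloInstruction.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for DotDimensionNumbers: one contracting dimension
// per operand, as in a plain rank-2 GEMM.
struct DotDims {
  int64_t lhs_contracting_dim;
  int64_t rhs_contracting_dim;
};

// Fold dot(x, transpose(y)) into dot(x, y). A transpose with permutation
// `perm` maps its output dimension i onto operand dimension perm[i], so the
// folded dot contracts whichever dimension of y the transpose had moved
// into the old contracting position.
DotDims FoldRhsTranspose(const DotDims& dims,
                         const std::vector<int64_t>& perm) {
  DotDims folded = dims;
  folded.rhs_contracting_dim = perm[dims.rhs_contracting_dim];
  return folded;
}

int main() {
  // Mirrors the FoldDotTranspose test below: x = f32[2,3], y = f32[2,3],
  // transpose(y) = f32[3,2] with dimensions={1,0}, and the dot contracts
  // LHS dim 1 against RHS dim 0.
  DotDims dims{/*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0};
  std::vector<int64_t> perm = {1, 0};

  DotDims folded = FoldRhsTranspose(dims, perm);
  // After folding, the dot reads y directly and contracts its dimension 1,
  // matching op::Dot(..., /*rhs_contracting_dim=*/1) in the test.
  assert(folded.rhs_contracting_dim == 1);
  return 0;
}

A dot whose contracting dims are 1 (LHS) and 0 (RHS) is canonical in the
sense of HloInstruction::CreateCanonicalDot, which is why the IR emitters
can now check for a "non-canonical contraction dimension" instead of a
separate transpose flag.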
Diffstat (limited to 'tensorflow/compiler/xla/service/transpose_folding_test.cc')
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding_test.cc | 219
1 file changed, 112 insertions(+), 107 deletions(-)
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 0319109f7f..f73f1227aa 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -31,9 +32,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
+namespace op = xla::testing::opcode_matchers;
+
namespace xla {
namespace {
@@ -54,83 +58,102 @@ class TransposeFoldingTest : public HloTestBase {
};
TEST_F(TransposeFoldingTest, FoldDotTranspose) {
- auto builder = HloComputation::Builder("entry_computation");
- HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
- /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
- /*name=*/"x"));
- HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
- /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
- /*name=*/"y"));
- HloInstruction* transpose_y =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
- DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(1);
- dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
- /*rhs=*/transpose_y, dot_dnums));
+ string hlo_string = R"(
+HloModule FoldDotTranspose
+
+ENTRY entry_computation {
+ x = f32[2,3]{1,0} parameter(0)
+ y = f32[2,3]{1,0} parameter(1)
+ transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0}
+ ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
- auto module = CreateNewModule("test_module");
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build(dot));
FoldTranspose(module.get());
- // Instructions after folding: x, y, and the fusion.
- std::unordered_set<HloInstruction*> instruction_set(
- entry_computation->instructions().begin(),
- entry_computation->instructions().end());
- CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
- CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
- CHECK_EQ(1, instruction_set.size())
- << "entry_computation should contain exactly 3 instructions.";
- HloInstruction* fusion = *instruction_set.begin();
- EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Dot(op::Parameter(0), op::Parameter(1),
+ /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
+}
+
+TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) {
+ string hlo_string = R"(
+HloModule FoldDotTranspose
- // The fusion instruction should contain two parameters, one transpose and
- // one dot.
- EXPECT_EQ(4, fusion->fused_instruction_count());
+ENTRY entry_computation {
+ x = f32[2,3] parameter(0)
+ y = f32[3,2] parameter(1)
+ transpose = f32[2,3] transpose(y), dimensions={1,0}
+ ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ TransposeFolding transpose_folding(
+ [](const HloInstruction& dot,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ },
+ [](const HloInstruction& convolution,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ });
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) {
+ string hlo_string = R"(
+HloModule FoldDotTranspose
+
+ENTRY entry_computation {
+ x = f32[3] parameter(0)
+ y = f32[3,2] parameter(1)
+ transpose = f32[2,3] transpose(y), dimensions={1,0}
+ ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={0}, lhs_contracting_dims={0}, rhs_contracting_dims={1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ TransposeFolding transpose_folding(
+ [](const HloInstruction& dot,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ },
+ [](const HloInstruction& convolution,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ });
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
+ EXPECT_FALSE(changed);
}
TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
- auto builder = HloComputation::Builder("entry_computation");
- // 2x1
- HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{1}, {2}})));
- // 3x2
- HloInstruction* const1 =
- builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2}, {3, 4}, {5, 6}})));
- HloInstruction* transpose0 =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0}));
- HloInstruction* transpose1 =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0}));
- DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(1);
- dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
- ShapeUtil::MakeShape(F32, {1, 3}),
- /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums));
+ string hlo_string = R"(
+HloModule FoldDotTransposeConstant
+
+ENTRY entry_computation {
+ constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } })
+ transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0}
+ constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } })
+ transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0}
+ ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
- auto module = CreateNewModule("test_module");
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build(dot));
FoldTranspose(module.get());
- for (auto* instruction : entry_computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kFusion) {
- CHECK_EQ(2, instruction->operand_count());
- EXPECT_EQ(const0, instruction->operand(0));
- EXPECT_EQ(const1, instruction->operand(1));
- }
- }
-
- // The created fusion instruction should contain two parameters, two
- // transposes (one for each parameter) and one dot.
- EXPECT_EQ(5,
- entry_computation->root_instruction()->fused_instruction_count());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Dot(op::Constant(), op::Constant(),
+ /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/1));
}
TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
@@ -164,50 +187,32 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
EXPECT_EQ(6, callee_computation->instruction_count());
}
-TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) {
- auto builder = HloComputation::Builder("entry_computation");
- HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
- /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
- /*name=*/"x"));
- HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
- /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
- /*name=*/"y"));
- HloInstruction* transpose_y =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
- DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(1);
- dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
- /*rhs=*/transpose_y, dot_dnums));
-
- auto module = CreateNewModule("test_module");
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build(dot));
+TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) {
+ string hlo_string = R"(
+HloModule FoldDotTransposeInCall
- HloInstruction* call = module->OutlineExpressionFromComputation(
- {transpose_y, dot}, "outlined", entry_computation);
+callee {
+ name.0 = f32[2,3]{1,0} parameter(0)
+ name.1 = f32[2,3]{1,0} parameter(1)
+ transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0}
+ ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+ENTRY entry_computation {
+ y = f32[2,3]{1,0} parameter(1)
+ x = f32[2,3]{1,0} parameter(0)
+ ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
FoldTranspose(module.get());
- // Instructions after folding: x, y, and the fusion.
- std::unordered_set<HloInstruction*> instruction_set(
- entry_computation->instructions().begin(),
- entry_computation->instructions().end());
- CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
- CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
- CHECK_EQ(1, instruction_set.erase(call))
- << "call is not in entry_computation.";
- CHECK(instruction_set.empty())
- << "entry_computation should contain exactly 3 instructions.";
- HloInstruction* fusion =
- call->called_computations().front()->root_instruction();
- EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
-
- // The fusion instruction should contain two parameters, one transpose and
- // one dot.
- EXPECT_EQ(4, fusion->fused_instruction_count());
+ const HloComputation* callee = module->GetComputationWithName("callee");
+ ASSERT_NE(callee, nullptr);
+ EXPECT_THAT(callee->root_instruction(),
+ op::Dot(op::Parameter(1), op::Parameter(0),
+ /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
}
// Test that a two dimension swap of the kernel gets folded into convolution.