aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/copy_insertion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc183
1 files changed, 183 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 892d0d7b54..3096206c34 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -1351,6 +1351,189 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
+TEST_F(CopyInsertionTest, CrossingParameters) {
+ // Test a case where two parameters' dataflow cross with each other while
+ // input and output are aliased with same index:
+ //
+ // (p0 , p1)
+ // | \ /|
+ // | \ / |
+ // alias X alias
+ // | / \ |
+ // | / \|
+ // (p1 , p0)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+ InsertCopies(module.get());
+
+ EXPECT_EQ(CountCopies(*module), 4);
+}
+
+TEST_F(CopyInsertionTest, ParametersAliasing) {
+ // Test a case where two parameters' dataflow don't interfere with each other
+ // while aliased.
+ //
+ // (p0 , p1)
+ // | |
+ // | |
+ // alias alias
+ // | |
+ // | |
+ // (p0 , p1)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+ InsertCopies(module.get());
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Tuple(op::Copy(op::GetTupleElement(param, 0)),
+ op::Copy(op::GetTupleElement(param, 1))));
+
+ EXPECT_EQ(CountCopies(*module), 2);
+}
+
+TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
+ // Test a case where one parameter is aliased with result while another one
+ // isn't.
+ //
+ // (p0 , p1)
+ // | |
+ // | |
+ // alias |
+ // | |
+ // | |
+ // (p0 , p1)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ InsertCopies(module.get());
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Tuple(op::Copy(op::GetTupleElement(param, 0)),
+ op::Copy(op::GetTupleElement(param, 1))));
+
+ EXPECT_EQ(CountCopies(*module), 2);
+}
+
+TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
+ // Test a case where one parameter is aliased with result while another one
+ // isn't.
+ //
+ // +-- (p0 , p1)
+ // | | |
+ // | | |
+ // alias Negate Negate
+ // | | |
+ // | | |
+ // +-- (p0 , p1)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+ auto negate0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+ builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ InsertCopies(module.get());
+
+ EXPECT_EQ(CountCopies(*module), 0);
+}
+
+TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
+ // Test a case where one parameter is aliased with result while another one
+ // isn't.
+ //
+ // +-- (p0 , p1)
+ // | | |
+ // | | |
+ // alias Negate Negate
+ // | | |
+ // | Add----+
+ // | | |
+ // +-- (p0 , p1)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+ auto negate0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape_, HloOpcode::kAdd, negate0, negate1));
+ builder.AddInstruction(HloInstruction::CreateTuple({add, negate1}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ InsertCopies(module.get());
+
+ EXPECT_EQ(CountCopies(*module), 0);
+}
+
TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
// Test a while instruction with a body which permutes its tuple parameter
// elements and applies one operation to one of the elements. The addition of