diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/copy_insertion_test.cc | 183 |
1 files changed, 0 insertions, 183 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 3096206c34..892d0d7b54 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1351,189 +1351,6 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); } -TEST_F(CopyInsertionTest, CrossingParameters) { - // Test a case where two parameters' dataflow cross with each other while - // input and output are aliased with same index: - // - // (p0 , p1) - // | \ /| - // | \ / | - // alias X alias - // | / \ | - // | / \| - // (p1 , p0) - auto module = CreateNewModule(); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "0")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); - module->AddEntryComputation(builder.Build()); - ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); - ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); - InsertCopies(module.get()); - - EXPECT_EQ(CountCopies(*module), 4); -} - -TEST_F(CopyInsertionTest, ParametersAliasing) { - // Test a case where two parameters' dataflow don't interfere with each other - // while aliased. - // - // (p0 , p1) - // | | - // | | - // alias alias - // | | - // | | - // (p0 , p1) - auto module = CreateNewModule(); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "p0")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); - module->AddEntryComputation(builder.Build()); - ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); - ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); - InsertCopies(module.get()); - - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(op::GetTupleElement(param, 0)), - op::Copy(op::GetTupleElement(param, 1)))); - - EXPECT_EQ(CountCopies(*module), 2); -} - -TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { - // Test a case where one parameter is aliased with result while another one - // isn't. - // - // (p0 , p1) - // | | - // | | - // alias | - // | | - // | | - // (p0 , p1) - auto module = CreateNewModule(); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "p0")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); - module->AddEntryComputation(builder.Build()); - ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); - InsertCopies(module.get()); - - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(op::GetTupleElement(param, 0)), - op::Copy(op::GetTupleElement(param, 1)))); - - EXPECT_EQ(CountCopies(*module), 2); -} - -TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { - // Test a case where one parameter is aliased with result while another one - // isn't. - // - // +-- (p0 , p1) - // | | | - // | | | - // alias Negate Negate - // | | | - // | | | - // +-- (p0 , p1) - auto module = CreateNewModule(); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "p0")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - - auto negate0 = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); - - auto negate1 = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); - builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); - module->AddEntryComputation(builder.Build()); - ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); - InsertCopies(module.get()); - - EXPECT_EQ(CountCopies(*module), 0); -} - -TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { - // Test a case where one parameter is aliased with result while another one - // isn't. - // - // +-- (p0 , p1) - // | | | - // | | | - // alias Negate Negate - // | | | - // | Add----+ - // | | | - // +-- (p0 , p1) - auto module = CreateNewModule(); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "p0")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - - auto negate0 = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); - - auto negate1 = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - scalar_shape_, HloOpcode::kAdd, negate0, negate1)); - builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); - module->AddEntryComputation(builder.Build()); - ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); - InsertCopies(module.get()); - - EXPECT_EQ(CountCopies(*module), 0); -} - TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // Test a while instruction with a body which permutes its tuple parameter // elements and applies one operation to one of the elements. The addition of |