diff options
author | 2018-09-05 17:17:23 -0700 | |
---|---|---|
committer | 2018-09-05 17:22:22 -0700 | |
commit | 6bd9f8fa0c17c55fc0c11ba0d9281cab1688b115 (patch) | |
tree | 1afd3dff710c4f63bae267807435abdcec784edb /tensorflow/compiler/xla/service/hlo_ordering_test.cc | |
parent | 017599d0a1fa7a7227a43649db67e96311033a4e (diff) |
Rollforward of cl/211656888 after fixing failing unit test.
*** Original change description ***
Add HloSchedule class representing a sequential order of an HloModule.
Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class which results in better encap...
***
PiperOrigin-RevId: 211726890
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering_test.cc | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 126d3a2d9c..6b6005e7a5 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -23,11 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -376,5 +378,104 @@ ENTRY root { dataflow->GetValueDefinedAt(add_3))); } +TEST_F(HloOrderingTest, + ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) { + // Tests that values live out of the module should interfere with values + // defined after the root instruction. That is: + // + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f))); + HloComputation* entry = + module->AddEntryComputation(builder.Build(/*root_instruction=*/root)); + + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param, root, dead}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + +TEST_F(HloOrderingTest, + ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) { + // Tests that values live out of a computation should interfere with values + // defined after the root instruction of the computation. That is: + // + // subcomputation: + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // entry computation: + // %c = constant(42.0) + // ROOT %call = call({%c}), subcomputation + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto subbuilder = HloComputation::Builder(TestName() + ".sub"); + HloInstruction* param = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = subbuilder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = subbuilder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f))); + HloComputation* subcomputation = module->AddEmbeddedComputation( + subbuilder.Build(/*root_instruction=*/root)); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {c}, subcomputation)); + HloComputation* entry = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(subcomputation, {param, root, dead}); + schedule.set_sequence(entry, {c, call}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + } // namespace } // namespace xla |