aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_ordering_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-05 10:34:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 10:41:37 -0700
commit7fa693209fe238478739b3982f652a7e35be91f3 (patch)
treeeb31635c366d9eceb144970ddb2b659441204ce1 /tensorflow/compiler/xla/service/hlo_ordering_test.cc
parent08313b87960962efb98bcd684776c8305fa9909a (diff)
Add HloSchedule class representing a sequential order of an HloModule.
Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class which results in better encapsulation of code which deals with schedules and better enforcement of invariants. This CL also fixes a corner-case bug in dataflow analysis, where values of instructions which are live out of the computation erroneously did not interfere with the values of instructions scheduled after the root instruction. PiperOrigin-RevId: 211656888
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc101
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 126d3a2d9c..6b6005e7a5 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -23,11 +23,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -376,5 +378,104 @@ ENTRY root {
dataflow->GetValueDefinedAt(add_3)));
}
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) {
+ // Tests that values live out of the module should interfere with values
+ // defined after the root instruction. That is:
+ //
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* entry =
+ module->AddEntryComputation(builder.Build(/*root_instruction=*/root));
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(entry, {param, root, dead});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) {
+ // Tests that values live out of a computation should interfere with values
+ // defined after the root instruction of the computation. That is:
+ //
+ // subcomputation:
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // entry computation:
+ // %c = constant(42.0)
+ // ROOT %call = call({%c}), subcomputation
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto subbuilder = HloComputation::Builder(TestName() + ".sub");
+ HloInstruction* param = subbuilder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = subbuilder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = subbuilder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* subcomputation = module->AddEmbeddedComputation(
+ subbuilder.Build(/*root_instruction=*/root));
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ HloInstruction* call = builder.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {c}, subcomputation));
+ HloComputation* entry = module->AddEntryComputation(builder.Build());
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(subcomputation, {param, root, dead});
+ schedule.set_sequence(entry, {c, call});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
} // namespace
} // namespace xla