aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_ordering_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-05 17:17:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 17:22:22 -0700
commit6bd9f8fa0c17c55fc0c11ba0d9281cab1688b115 (patch)
tree1afd3dff710c4f63bae267807435abdcec784edb /tensorflow/compiler/xla/service/hlo_ordering_test.cc
parent017599d0a1fa7a7227a43649db67e96311033a4e (diff)
Rollforward of cl/211656888 after fixing failing unit test.
*** Original change description *** Add HloSchedule class representing a sequential order of an HloModule. Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class which results in better encap... *** PiperOrigin-RevId: 211726890
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc101
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 126d3a2d9c..6b6005e7a5 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -23,11 +23,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -376,5 +378,104 @@ ENTRY root {
dataflow->GetValueDefinedAt(add_3)));
}
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) {
+ // Tests that values live out of the module should interfere with values
+ // defined after the root instruction. That is:
+ //
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* entry =
+ module->AddEntryComputation(builder.Build(/*root_instruction=*/root));
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(entry, {param, root, dead});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) {
+ // Tests that values live out of a computation should interfere with values
+ // defined after the root instruction of the computation. That is:
+ //
+ // subcomputation:
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // entry computation:
+ // %c = constant(42.0)
+ // ROOT %call = call({%c}), subcomputation
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto subbuilder = HloComputation::Builder(TestName() + ".sub");
+ HloInstruction* param = subbuilder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = subbuilder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = subbuilder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* subcomputation = module->AddEmbeddedComputation(
+ subbuilder.Build(/*root_instruction=*/root));
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ HloInstruction* call = builder.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {c}, subcomputation));
+ HloComputation* entry = module->AddEntryComputation(builder.Build());
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(subcomputation, {param, root, dead});
+ schedule.set_sequence(entry, {c, call});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
} // namespace
} // namespace xla