aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-09-30 22:34:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-30 22:39:19 -0700
commitb797bfb750504e03a38a988c44e3c52e902e87c4 (patch)
tree2359a3ee011deb519752a84d45fec77f28d961b0 /tensorflow/compiler
parent987954ce50583409e54828a044e0866bcfdbd88a (diff)
[HloOrdering] Make parameter always defined before other instructions.
- Make parameter always defined before other instructions. - Add extra indentations to the predecessor field in ToString() method to make it clear. PiperOrigin-RevId: 215162840
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc20
2 files changed, 27 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index f1dc08bafa..23d41d91d6 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -92,14 +92,18 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
}
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
- // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
- // is live into the module.
+ // Entry parameter should always be defined before other instructions.
const HloModule* module = b.defining_instruction()->parent()->parent();
if (b.defining_instruction()->parent() == module->entry_computation() &&
b.defining_instruction()->opcode() == HloOpcode::kParameter) {
return false;
}
+ if (a.defining_instruction()->parent() == module->entry_computation() &&
+ a.defining_instruction()->opcode() == HloOpcode::kParameter) {
+ return true;
+ }
+
// Phi values require special handling. Because XLA does not have a phi
// instruction, the definition instruction of the phis values are
// placeholders: either the subcomputation parameter (body or condition) or
@@ -316,7 +320,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const {
for (auto predecessor : all) {
if (predecessors_.at(computation)
->IsReachable(predecessor, instruction)) {
- pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
+ pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
}
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 00970bcda3..b045adc964 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -174,6 +174,26 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
+TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) {
+ // Entry parameter should always be defined before other instruction.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ module->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ DependencyHloOrdering ordering(module.get());
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(param),
+ dataflow->GetValueDefinedAt(constant)));
+ EXPECT_TRUE(!ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(param)));
+}
+
TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// Tests the ordering of values (defined by dataflow analysis) in the body and
// condition of a while instruction. HLO code: