aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_ordering_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-09-01 09:17:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-01 09:21:31 -0700
commit73d796423348347702d43b498257f34e41fba367 (patch)
treed9be8036c53d39d3ab4abddd7793724fc5dc52d1 /tensorflow/compiler/xla/service/hlo_ordering_test.cc
parent6e8d0c632dea30758c7cc343decdf8ab7956e59d (diff)
Rollback update-ability of dataflow and alias analysis added in cl/164923041 and cl/64778750. It did not scale as intended to large graphs when used in copy insertion. This change also includes some simplification and performance improvements to dataflow and alias analysis. Also add some value-ordering tests to HloOrderingTest using dataflow analysis to generate values.
PiperOrigin-RevId: 167283460
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc89
1 files changed, 89 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index ad6070a9c1..c95e44bd5d 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
@@ -218,6 +219,94 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
+TEST_F(HloOrderingTest, ValuesInWhileComputations) {
+ // Tests the ordering of values (defined by dataflow analysis) in the body and
+ // condition of a while instruction. HLO code:
+ //
+ // body(F32[]) %param):
+ // %negate = Negate(%param)
+ //
+ // condition(F32[] %param):
+ // %convert = Convert<PRED>(%param)
+ //
+ // entry:
+ // %constant = Constant(1.0)
+ // %while = While(%constant, body, condition)
+ // %add = Add(%constant, %while)
+ //
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto body_builder = HloComputation::Builder("body");
+ auto body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
+ auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape, HloOpcode::kNegate, body_param));
+ HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
+
+ auto cond_builder = HloComputation::Builder("condition");
+ auto cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
+ auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
+ HloComputation* condition =
+ module->AddEmbeddedComputation(cond_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto xla_while = builder.AddInstruction(
+ HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kAdd, constant, xla_while));
+ module->AddEntryComputation(builder.Build());
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
+ DependencyHloOrdering ordering(module.get());
+
+ // Init value is defined before the while, but live range is not before the
+ // while because of the use of the init value in the add.
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_FALSE(
+ ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(xla_while)));
+
+ // Any value defined in the body or condition is defined before the while, and
+ // has a live range strictly before the while.
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_TRUE(
+ ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
+ dataflow->GetValueDefinedAt(xla_while)));
+
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_TRUE(
+ ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
+ dataflow->GetValueDefinedAt(xla_while)));
+
+ // The live range of the while should be before the add.
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
+ dataflow->GetValueDefinedAt(add)));
+ ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
+
+ const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
+ EXPECT_EQ(while_use.instruction, add);
+ EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
+ while_use, dataflow->GetValueDefinedAt(add)));
+ EXPECT_TRUE(
+ ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
+ dataflow->GetValueDefinedAt(add)));
+}
+
} // namespace
} // namespace xla