aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-05 10:34:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 10:41:37 -0700
commit7fa693209fe238478739b3982f652a7e35be91f3 (patch)
treeeb31635c366d9eceb144970ddb2b659441204ce1 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parent08313b87960962efb98bcd684776c8305fa9909a (diff)
Add HloSchedule class representing a sequential order of an HloModule.
Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class which results in better encapsulation of code which deals with schedules and better enforcement of invariants. This CL also fixes a corner-case bug in dataflow analysis, where values of instructions which are live out of the computation erroneously did not interfere with the values of instructions scheduled after the root instruction. PiperOrigin-RevId: 211656888
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc29
1 files changed, 16 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 62eea2b06c..0a86f83ed9 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param0, negate, param1, exp, add}});
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
bool ssa_form = GetParam();
RunAnalysis(ssa_form);
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param, xla_while}});
- sequence.insert({condition, {cond_param, cond_constant}});
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, xla_while});
+ schedule.set_sequence(condition, {cond_param, cond_constant});
// Construct the order such that 'constant' and its use 'exp' are before
// body_param.
- sequence.insert({body, {constant, exp, body_param, add}});
+ schedule.set_sequence(
+ body, {constant, exp, body_param, add, dead_constant, dead_negate});
+ TF_ASSERT_OK(schedule.Verify());
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
// 'add' is live out of the body and will interfere with an later instructions
// such as 'dead_constant' and 'dead_negate'.
@@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- std::vector<const HloInstruction*> order = {param, negate, exp, add};
- sequence.emplace(entry, order);
-
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, negate, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));