aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-05 17:17:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 17:22:22 -0700
commit6bd9f8fa0c17c55fc0c11ba0d9281cab1688b115 (patch)
tree1afd3dff710c4f63bae267807435abdcec784edb /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parent017599d0a1fa7a7227a43649db67e96311033a4e (diff)
Rollforward of cl/211656888 after fixing failing unit test.
*** Original change description *** Add HloSchedule class representing a sequential order of an HloModule. Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class which results in better encap... *** PiperOrigin-RevId: 211726890
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc29
1 files changed, 16 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 72b236801a..510d6360a1 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param0, negate, param1, exp, add}});
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
bool ssa_form = GetParam();
RunAnalysis(ssa_form);
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param, xla_while}});
- sequence.insert({condition, {cond_param, cond_constant}});
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, xla_while});
+ schedule.set_sequence(condition, {cond_param, cond_constant});
// Construct the order such that 'constant' and its use 'exp' are before
// body_param.
- sequence.insert({body, {constant, exp, body_param, add}});
+ schedule.set_sequence(
+ body, {constant, exp, body_param, add, dead_constant, dead_negate});
+ TF_ASSERT_OK(schedule.Verify());
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
// 'add' is live out of the body and will interfere with an later instructions
// such as 'dead_constant' and 'dead_negate'.
@@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- std::vector<const HloInstruction*> order = {param, negate, exp, add};
- sequence.emplace(entry, order);
-
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, negate, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));