aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-01-10 14:52:27 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-01-10 15:03:36 -0800
commit: 4c24ab2d0651b048d81c4743b73ac92b5c39d8cc (patch)
tree: 673b9b0e957f1a100196172904c34bd74a447e3b
parent: 69f231135304799f581d34df37d72cdc07fc8f58 (diff)
[TF:XLA] Fix infinite loop in HLO data flow analysis.
Merge input values at phi nodes correctly: If a phi operand is the phi itself, and the other operands are all the same, then the phi node is redundant.

PiperOrigin-RevId: 181521522
-rw-r--r-- tensorflow/compiler/xla/service/BUILD 2
-rw-r--r-- tensorflow/compiler/xla/service/copy_insertion_test.cc 185
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc 14
3 files changed, 198 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f26dc64fee..16d227a00f 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1684,6 +1684,7 @@ tf_cc_test(
":hlo",
":hlo_graph_dumper",
":hlo_matchers",
+ ":hlo_runner",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1691,7 +1692,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 8388574716..128ee726ea 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -1726,5 +1727,189 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) {
BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
+TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
+ const string& hlo_string = R"(
+HloModule TestModule
+
+if-body.v5 {
+ constant.3 = s32[] constant(-1)
+ p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
+ get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
+ get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
+ get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
+ add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
+ tuple.33 = (s32[]) tuple(add.3)
+ ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
+}
+
+if-condition.v4 {
+ p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
+ get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
+ constant.4 = s32[] constant(0)
+ ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
+}
+
+_functionalize_body_1__.v28 {
+ arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0
+ constant.7 = s32[] constant(1)
+ add.4 = s32[] add(get-tuple-element.68, constant.7)
+ get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
+ get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
+ less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70)
+ constant.8 = s32[] constant(0)
+ select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
+ get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
+ tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70)
+ tuple.36 = (s32[]) tuple(constant.8)
+ tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36)
+ while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5
+ get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2
+ get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0
+ ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73)
+}
+
+cond_wrapper.v3.1 {
+ inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
+ constant.11 = s32[] constant(7)
+ ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11)
+}
+
+_functionalize_body_2__.v25 {
+ arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0
+ get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2
+ get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3
+ get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4
+ tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79)
+ while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
+ get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0
+ get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1
+ constant.12 = s32[] constant(1)
+ add.5 = s32[] add(get-tuple-element.81, constant.12)
+ get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3
+ ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82)
+}
+
+cond_wrapper.v3.2 {
+ inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
+ constant.13 = s32[] constant(5)
+ ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13)
+}
+
+ENTRY TestComputation {
+ arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
+ ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
+}
+)";
+ auto module_or_status =
+ HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
+ auto module = module_or_status.ConsumeValueOrDie();
+ InsertCopies(module.get());
+}
+
+TEST_F(CopyInsertionTest, ControlFlowTest) {
+ const string& hlo_string = R"(
+HloModule TestModule
+
+if-body.v5 {
+ constant.3 = s32[] constant(-1)
+ p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
+ get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
+ get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
+ get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
+ add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
+ tuple.33 = (s32[]) tuple(add.3)
+ ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
+}
+
+if-condition.v4 {
+ p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
+ get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
+ constant.4 = s32[] constant(0)
+ ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
+}
+
+if-body.v5.1 {
+ constant.5 = s32[] constant(-1)
+ p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
+ get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1
+ get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2
+ multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70)
+ tuple.35 = (s32[]) tuple(multiply.1)
+ ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35)
+}
+
+if-condition.v4.1 {
+ p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
+ get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
+ constant.6 = s32[] constant(1)
+ ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6)
+}
+
+_functionalize_body_1__.v28 {
+ arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0
+ constant.7 = s32[] constant(1)
+ add.4 = s32[] add(get-tuple-element.72, constant.7)
+ get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
+ get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
+ less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74)
+ constant.8 = s32[] constant(0)
+ select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
+ get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
+ tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74)
+ tuple.38 = (s32[]) tuple(constant.8)
+ tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38)
+ while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5
+ while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1
+ get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2
+ get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0
+ ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77)
+}
+
+cond_wrapper.v3.1 {
+ inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
+ constant.11 = s32[] constant(7)
+ ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11)
+}
+
+_functionalize_body_2__.v25 {
+ arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0
+ get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2
+ get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3
+ get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4
+ tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82)
+ while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
+ get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0
+ get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1
+ constant.12 = s32[] constant(1)
+ add.5 = s32[] add(get-tuple-element.84, constant.12)
+ get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3
+ ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85)
+}
+
+cond_wrapper.v3.2 {
+ inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
+ constant.13 = s32[] constant(5)
+ ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13)
+}
+
+ENTRY TestComputation {
+ arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
+ ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
+}
+)";
+ auto module_or_status =
+ HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
+ auto module = module_or_status.ConsumeValueOrDie();
+ InsertCopies(module.get());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 80d89d851e..d25fc5d741 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -154,7 +154,11 @@ bool HloDataflowAnalysis::Phi(
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
CHECK(ssa_form_);
VLOG(4) << "Phi(" << instruction->name() << ")";
-
+ VLOG(5) << "instruction value set = "
+ << GetInstructionValueSet(instruction).ToString();
+ for (const InstructionValueSet* input : inputs) {
+ VLOG(5) << "input value set = " << input->ToString();
+ }
for (const InstructionValueSet* input : inputs) {
DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
}
@@ -171,9 +175,14 @@ bool HloDataflowAnalysis::Phi(
value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
// Construct a vector of unique value IDs of the inputs.
+ // Don't add value ids where the input is equal to the definition.
std::vector<HloValue::Id> input_value_ids;
for (const InstructionValueSet* input : inputs) {
for (const HloValue* value : input->element(index).values()) {
+ if (value->defining_instruction() == instruction &&
+ value->defining_index() == index) {
+ continue;
+ }
input_value_ids.push_back(value->id());
}
}
@@ -190,6 +199,7 @@ bool HloDataflowAnalysis::Phi(
current_value->defining_instruction() == instruction &&
current_value->defining_index() == index);
if (current_value_defined_here) {
+ VLOG(5) << "current_value_defined_here: " << current_value->ToString();
CHECK(current_value->is_phi());
auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
current_value->id());
@@ -197,7 +207,7 @@ bool HloDataflowAnalysis::Phi(
input_value_ids.erase(it);
}
}
-
+ VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
if (input_value_ids.empty()) {
// A value set which has at least one element should never have its value
// set reduced to zero elements. During dataflow value sets only can go