diff options
| field | value |
|---|---|
| author | 2018-01-10 14:52:27 -0800 |
| committer | 2018-01-10 15:03:36 -0800 |
| commit | 4c24ab2d0651b048d81c4743b73ac92b5c39d8cc (patch) |
| tree | 673b9b0e957f1a100196172904c34bd74a447e3b |
| parent | 69f231135304799f581d34df37d72cdc07fc8f58 (diff) |
[TF:XLA] Fix infinite loop in HLO data flow analysis.
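Merge input values at phi nodes correctly: if one operand of a phi is the phi
itself and all of its other operands are the same value, the phi node is
redundant. When collecting a phi's input value IDs, values defined by the phi
itself (same instruction and index) are now skipped.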
PiperOrigin-RevId: 181521522
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/copy_insertion_test.cc | 185 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 14 |
3 files changed, 198 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f26dc64fee..16d227a00f 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1684,6 +1684,7 @@ tf_cc_test( ":hlo", ":hlo_graph_dumper", ":hlo_matchers", + ":hlo_runner", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1691,7 +1692,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 8388574716..128ee726ea 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -1726,5 +1727,189 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); +TEST_F(CopyInsertionTest, SimpleControlFlowTest) { + const string& hlo_string = R"( +HloModule TestModule + +if-body.v5 { + constant.3 = s32[] constant(-1) + p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 + get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 + get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 + add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) + tuple.33 = (s32[]) tuple(add.3) + ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) +} + +if-condition.v4 { + p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 + constant.4 = s32[] constant(0) + ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) +} + +_functionalize_body_1__.v28 { + arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0 + constant.7 = s32[] constant(1) + add.4 = s32[] add(get-tuple-element.68, constant.7) + get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 + get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 + less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70) 
+ constant.8 = s32[] constant(0) + select = s32[] select(less-than-or-equal-to, constant.8, constant.7) + get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 + tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70) + tuple.36 = (s32[]) tuple(constant.8) + tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36) + while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5 + get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2 + get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0 + ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73) +} + +cond_wrapper.v3.1 { + inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 + constant.11 = s32[] constant(7) + ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11) +} + +_functionalize_body_2__.v25 { + arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0 + get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2 + get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3 + get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4 + tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79) + while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 + get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0 + get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1 + constant.12 = s32[] constant(1) + add.5 = s32[] add(get-tuple-element.81, constant.12) + get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3 + ROOT tuple.40 = 
(s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82) +} + +cond_wrapper.v3.2 { + inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 + constant.13 = s32[] constant(5) + ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13) +} + +ENTRY TestComputation { + arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 +} +)"; + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module = module_or_status.ConsumeValueOrDie(); + InsertCopies(module.get()); +} + +TEST_F(CopyInsertionTest, ControlFlowTest) { + const string& hlo_string = R"( +HloModule TestModule + +if-body.v5 { + constant.3 = s32[] constant(-1) + p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 + get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 + get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 + add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) + tuple.33 = (s32[]) tuple(add.3) + ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) +} + +if-condition.v4 { + p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 + constant.4 = s32[] constant(0) + ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) +} + +if-body.v5.1 { + constant.5 = s32[] constant(-1) + p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1 + get-tuple-element.70 = s32[] 
get-tuple-element(get-tuple-element.68), index=2 + multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70) + tuple.35 = (s32[]) tuple(multiply.1) + ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35) +} + +if-condition.v4.1 { + p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 + constant.6 = s32[] constant(1) + ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6) +} + +_functionalize_body_1__.v28 { + arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0 + constant.7 = s32[] constant(1) + add.4 = s32[] add(get-tuple-element.72, constant.7) + get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 + get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 + less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74) + constant.8 = s32[] constant(0) + select = s32[] select(less-than-or-equal-to, constant.8, constant.7) + get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 + tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74) + tuple.38 = (s32[]) tuple(constant.8) + tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38) + while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5 + while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1 + get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2 + get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0 + ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77) +} + +cond_wrapper.v3.1 { + inputs.1 = (s32[], s32[], s32[], s32[]) 
parameter(0) + get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0 + constant.11 = s32[] constant(7) + ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11) +} + +_functionalize_body_2__.v25 { + arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0 + get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2 + get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3 + get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4 + tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82) + while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 + get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0 + get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1 + constant.12 = s32[] constant(1) + add.5 = s32[] add(get-tuple-element.84, constant.12) + get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3 + ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85) +} + +cond_wrapper.v3.2 { + inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 + constant.13 = s32[] constant(5) + ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13) +} + +ENTRY TestComputation { + arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 +} +)"; + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module = module_or_status.ConsumeValueOrDie(); + InsertCopies(module.get()); +} + } // namespace } // namespace 
xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 80d89d851e..d25fc5d741 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -154,7 +154,11 @@ bool HloDataflowAnalysis::Phi( tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) { CHECK(ssa_form_); VLOG(4) << "Phi(" << instruction->name() << ")"; - + VLOG(5) << "instruction value set = " + << GetInstructionValueSet(instruction).ToString(); + for (const InstructionValueSet* input : inputs) { + VLOG(5) << "input value set = " << input->ToString(); + } for (const InstructionValueSet* input : inputs) { DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); } @@ -171,9 +175,14 @@ bool HloDataflowAnalysis::Phi( value_set.values().size() == 1 ? value_set.values()[0] : nullptr; // Construct a vector of unique value IDs of the inputs. + // Don't add value ids where the input is equal to the definition. 
std::vector<HloValue::Id> input_value_ids; for (const InstructionValueSet* input : inputs) { for (const HloValue* value : input->element(index).values()) { + if (value->defining_instruction() == instruction && + value->defining_index() == index) { + continue; + } input_value_ids.push_back(value->id()); } } @@ -190,6 +199,7 @@ bool HloDataflowAnalysis::Phi( current_value->defining_instruction() == instruction && current_value->defining_index() == index); if (current_value_defined_here) { + VLOG(5) << "current_value_defined_here: " << current_value->ToString(); CHECK(current_value->is_phi()); auto it = std::find(input_value_ids.begin(), input_value_ids.end(), current_value->id()); @@ -197,7 +207,7 @@ bool HloDataflowAnalysis::Phi( input_value_ids.erase(it); } } - + VLOG(5) << "after input_value_ids.size = " << input_value_ids.size(); if (input_value_ids.empty()) { // A value set which has at least one element should never have its value // set reduced to zero elements. During dataflow value sets only can go |