diff options
author | 2017-06-14 14:32:48 -0700 | |
---|---|---|
committer | 2017-06-14 14:36:34 -0700 | |
commit | a7c36173cabcc1289a836e8143accb5f0914b19a (patch) | |
tree | e9d2cc747d3bceb067afe41cdcb4ae0f788e5faa | |
parent | f5fcd1fdcf896f46aed03c7e61525b48b75d1acc (diff) |
Use a non-recursive DFS in HloInstruction::Accept to avoid stack
overflow on deep graphs
Even with this fix, we don't finish compiling the exact test case from
b/38494745 in a reasonable amount of time (we spend a lot of time
inside HloInstruction::FusionReusesParamElements::ComputeInternal, for
instance), so I've used a smaller graph depth for now to avoid timing
out the test.
PiperOrigin-RevId: 159026595
-rw-r--r-- | tensorflow/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 128 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/deep_graph_test.cc | 58 |
5 files changed, 140 insertions, 64 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 60d58dd594..bb69bf172f 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -39,7 +39,7 @@ config_setting( config_setting( name = "android_armeabi", values = { - "cc_target_os": "android", + "crosstool_top": "//external:android/crosstool", "cpu": "armeabi", }, visibility = ["//visibility:public"], diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f926cb4bc7..6bb9e9a9e6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1893,72 +1893,86 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { HloOpcodeString(opcode_).c_str()); } -Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order, - bool ignore_control_predecessors) { - // Do not visit this HLO node again if it is already visited. - if (visitor->DidVisit(*this)) { - VLOG(3) << "Not visiting HLO " << name() << " as it was already visited."; - return Status::OK(); - } - - // If the instruction is in the visiting state, it means a cycle. - if (visitor->IsVisiting(*this)) { +static Status PushDFSChild(DfsHloVisitor* visitor, + std::vector<HloInstruction*>* dfs_stack, + HloInstruction* parent, HloInstruction* child) { + if (visitor->IsVisiting(*child)) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - ToString().c_str()); - } - visitor->SetVisiting(*this); - - // Sort operands, if an ordering was provided. 'temp_sorted_operands' must - // live at this scope, since 'operands' will point to it if the operands are - // sorted. The purpose of the 'operands' pointer is to avoid copying the - // operands in the common case where the operands are not sorted. - std::vector<HloInstruction*>* operands = &operands_; - std::vector<HloInstruction*> temp_sorted_operands; - if (operand_order != nullptr) { - temp_sorted_operands = operands_; - std::sort(temp_sorted_operands.begin(), temp_sorted_operands.end(), - *operand_order); - operands = &temp_sorted_operands; - } - for (HloInstruction* operand : *operands) { - VLOG(3) << "Going to visit HLO " << operand->name() << " as operand of HLO " - << name(); - TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor, operand_order, - ignore_control_predecessors)); - } - - if (!ignore_control_predecessors) { - // This uses the same pointer/vector sorting to avoid extra copies as above. - std::vector<HloInstruction*>* predecessors = &control_predecessors_; - std::vector<HloInstruction*> temp_sorted_predecessors; - if (operand_order != nullptr) { - temp_sorted_predecessors = control_predecessors_; - std::sort(temp_sorted_predecessors.begin(), - temp_sorted_predecessors.end(), *operand_order); - predecessors = &temp_sorted_predecessors; + parent->ToString().c_str()); + } + + if (!visitor->DidVisit(*child)) { + dfs_stack->push_back(child); + } else { + VLOG(3) << "Not visiting HLO " << child->name() + << " as it was already visited."; + } + return Status::OK(); +} + +static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, + const HloInstruction::CompareFunction* operand_order, + bool ignore_control_predecessors) { + std::vector<HloInstruction*> dfs_stack; + dfs_stack.push_back(root); + + do { + DCHECK(!dfs_stack.empty()); + + HloInstruction* current_node = dfs_stack.back(); + if (visitor->DidVisit(*current_node)) { + dfs_stack.pop_back(); + VLOG(3) << "Not visiting HLO " << current_node->name() + << " as it was already visited."; + continue; } - for (HloInstruction* control_predecessor : *predecessors) { - VLOG(3) << "Going to visit HLO " << control_predecessor->name() - << " as a control predecessor of HLO " << name(); - TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal( - visitor, operand_order, ignore_control_predecessors)); + + if (visitor->IsVisiting(*current_node)) { + dfs_stack.pop_back(); + + TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); + VLOG(2) << "Visiting HLO " << current_node->name(); + TF_RETURN_IF_ERROR(current_node->Visit(visitor)); + visitor->SetVisited(*current_node); + TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); + continue; } - } - TF_RETURN_IF_ERROR(visitor->Preprocess(this)); - VLOG(2) << "Visiting HLO " << name(); - TF_RETURN_IF_ERROR(Visit(visitor)); - visitor->SetVisited(*this); - return visitor->Postprocess(this); + visitor->SetVisiting(*current_node); + + const size_t old_dfs_stack_size = dfs_stack.size(); + + for (HloInstruction* child : current_node->operands()) { + TF_RETURN_IF_ERROR( + PushDFSChild(visitor, &dfs_stack, current_node, child)); + } + + if (!ignore_control_predecessors) { + for (HloInstruction* child : current_node->control_predecessors()) { + TF_RETURN_IF_ERROR( + PushDFSChild(visitor, &dfs_stack, current_node, child)); + } + } + + if (operand_order != nullptr) { + std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(), + *operand_order); + } + + // This makes the traversal order the same as what you'd expect + // out of a recursive algorithm. + std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end()); + } while (!dfs_stack.empty()); + + return Status::OK(); } Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, bool ignore_control_predecessors) { VLOG(2) << "HloInstruction::Accept(" << name() << ")"; TF_RETURN_IF_ERROR( - AcceptInternal(visitor, nullptr, ignore_control_predecessors)); + PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } @@ -1969,8 +1983,8 @@ Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; - TF_RETURN_IF_ERROR(AcceptInternal(visitor, &operand_order, - /*ignore_control_predecessors=*/false)); + TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order, + /*ignore_control_predecessors=*/false)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index f98bafe81e..cb19c84814 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -825,12 +825,6 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands); - // Inner DFS traversal function -- this function being called (rather than - // Accept above) allows us to distinguish the root of the traversal. - Status AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order, - bool ignore_control_predecessors); - // CHECKs various invariants of a fusion instruction. void CheckFusionInstruction() const; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 57523b34bb..5fa515b26f 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1357,6 +1357,16 @@ xla_test( ], ) +xla_test( + name = "deep_graph_test", + srcs = ["deep_graph_test.cc"], + deps = [ + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + ], +) + cc_test( name = "literal_test_util_test", srcs = ["literal_test_util_test.cc"], diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc new file mode 100644 index 0000000000..7a5601ada3 --- /dev/null +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" + +namespace xla { +namespace { +TEST_F(ClientLibraryTestBase, DeepGraph) { + // TODO(b/62624812): To trigger the stack overflow this test is + // intended to track, we need to set kDepth to 20000. + // Unfortunately, setting it that high causes the test to time out. + const int kDepth = 200; + ComputationBuilder b(client_, TestName()); + ComputationDataHandle x; + ComputationDataHandle y; + auto x_data = CreateR0Parameter<int32>(3, 0, "x", &b, &x); + auto y_data = CreateR0Parameter<int32>(1, 1, "y", &b, &y); + ComputationDataHandle z = x; + for (int i = 0; i < kDepth; ++i) { + z = b.Add(z, y); + } + ComputeAndCompareR0<int32>(&b, /*expected=*/kDepth + 3, + {x_data.get(), y_data.get()}); +} +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector<tensorflow::Flag> flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::legacy_flags::AppendUserComputationFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} |