about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-06-14 14:32:48 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-06-14 14:36:34 -0700
commit: a7c36173cabcc1289a836e8143accb5f0914b19a (patch)
tree: e9d2cc747d3bceb067afe41cdcb4ae0f788e5faa
parent: f5fcd1fdcf896f46aed03c7e61525b48b75d1acc (diff)
Use a non-recursive DFS in HloInstruction::Accept to avoid stack overflow on deep graphs

Even with this fix, we don't finish compiling the exact test case from b/38494745 in a reasonable amount of time (we spend a lot of time inside HloInstruction::FusionReusesParamElements::ComputeInternal, for instance), so I've used a smaller graph depth for now to avoid timing out the test.

PiperOrigin-RevId: 159026595
-rw-r--r-- tensorflow/BUILD 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc 128
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h 6
-rw-r--r-- tensorflow/compiler/xla/tests/BUILD 10
-rw-r--r-- tensorflow/compiler/xla/tests/deep_graph_test.cc 58
5 files changed, 140 insertions, 64 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 60d58dd594..bb69bf172f 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -39,7 +39,7 @@ config_setting(
config_setting(
name = "android_armeabi",
values = {
- "cc_target_os": "android",
+ "crosstool_top": "//external:android/crosstool",
"cpu": "armeabi",
},
visibility = ["//visibility:public"],
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index f926cb4bc7..6bb9e9a9e6 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1893,72 +1893,86 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
HloOpcodeString(opcode_).c_str());
}
-Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor,
- const CompareFunction* operand_order,
- bool ignore_control_predecessors) {
- // Do not visit this HLO node again if it is already visited.
- if (visitor->DidVisit(*this)) {
- VLOG(3) << "Not visiting HLO " << name() << " as it was already visited.";
- return Status::OK();
- }
-
- // If the instruction is in the visiting state, it means a cycle.
- if (visitor->IsVisiting(*this)) {
+static Status PushDFSChild(DfsHloVisitor* visitor,
+ std::vector<HloInstruction*>* dfs_stack,
+ HloInstruction* parent, HloInstruction* child) {
+ if (visitor->IsVisiting(*child)) {
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
- ToString().c_str());
- }
- visitor->SetVisiting(*this);
-
- // Sort operands, if an ordering was provided. 'temp_sorted_operands' must
- // live at this scope, since 'operands' will point to it if the operands are
- // sorted. The purpose of the 'operands' pointer is to avoid copying the
- // operands in the common case where the operands are not sorted.
- std::vector<HloInstruction*>* operands = &operands_;
- std::vector<HloInstruction*> temp_sorted_operands;
- if (operand_order != nullptr) {
- temp_sorted_operands = operands_;
- std::sort(temp_sorted_operands.begin(), temp_sorted_operands.end(),
- *operand_order);
- operands = &temp_sorted_operands;
- }
- for (HloInstruction* operand : *operands) {
- VLOG(3) << "Going to visit HLO " << operand->name() << " as operand of HLO "
- << name();
- TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor, operand_order,
- ignore_control_predecessors));
- }
-
- if (!ignore_control_predecessors) {
- // This uses the same pointer/vector sorting to avoid extra copies as above.
- std::vector<HloInstruction*>* predecessors = &control_predecessors_;
- std::vector<HloInstruction*> temp_sorted_predecessors;
- if (operand_order != nullptr) {
- temp_sorted_predecessors = control_predecessors_;
- std::sort(temp_sorted_predecessors.begin(),
- temp_sorted_predecessors.end(), *operand_order);
- predecessors = &temp_sorted_predecessors;
+ parent->ToString().c_str());
+ }
+
+ if (!visitor->DidVisit(*child)) {
+ dfs_stack->push_back(child);
+ } else {
+ VLOG(3) << "Not visiting HLO " << child->name()
+ << " as it was already visited.";
+ }
+ return Status::OK();
+}
+
+static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
+ const HloInstruction::CompareFunction* operand_order,
+ bool ignore_control_predecessors) {
+ std::vector<HloInstruction*> dfs_stack;
+ dfs_stack.push_back(root);
+
+ do {
+ DCHECK(!dfs_stack.empty());
+
+ HloInstruction* current_node = dfs_stack.back();
+ if (visitor->DidVisit(*current_node)) {
+ dfs_stack.pop_back();
+ VLOG(3) << "Not visiting HLO " << current_node->name()
+ << " as it was already visited.";
+ continue;
}
- for (HloInstruction* control_predecessor : *predecessors) {
- VLOG(3) << "Going to visit HLO " << control_predecessor->name()
- << " as a control predecessor of HLO " << name();
- TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal(
- visitor, operand_order, ignore_control_predecessors));
+
+ if (visitor->IsVisiting(*current_node)) {
+ dfs_stack.pop_back();
+
+ TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
+ VLOG(2) << "Visiting HLO " << current_node->name();
+ TF_RETURN_IF_ERROR(current_node->Visit(visitor));
+ visitor->SetVisited(*current_node);
+ TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
+ continue;
}
- }
- TF_RETURN_IF_ERROR(visitor->Preprocess(this));
- VLOG(2) << "Visiting HLO " << name();
- TF_RETURN_IF_ERROR(Visit(visitor));
- visitor->SetVisited(*this);
- return visitor->Postprocess(this);
+ visitor->SetVisiting(*current_node);
+
+ const size_t old_dfs_stack_size = dfs_stack.size();
+
+ for (HloInstruction* child : current_node->operands()) {
+ TF_RETURN_IF_ERROR(
+ PushDFSChild(visitor, &dfs_stack, current_node, child));
+ }
+
+ if (!ignore_control_predecessors) {
+ for (HloInstruction* child : current_node->control_predecessors()) {
+ TF_RETURN_IF_ERROR(
+ PushDFSChild(visitor, &dfs_stack, current_node, child));
+ }
+ }
+
+ if (operand_order != nullptr) {
+ std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
+ *operand_order);
+ }
+
+ // This makes the traversal order the same as what you'd expect
+ // out of a recursive algorithm.
+ std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
+ } while (!dfs_stack.empty());
+
+ return Status::OK();
}
Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit,
bool ignore_control_predecessors) {
VLOG(2) << "HloInstruction::Accept(" << name() << ")";
TF_RETURN_IF_ERROR(
- AcceptInternal(visitor, nullptr, ignore_control_predecessors));
+ PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
if (call_finish_visit) {
TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
}
@@ -1969,8 +1983,8 @@ Status HloInstruction::AcceptWithOperandOrder(
DfsHloVisitor* visitor, const CompareFunction& operand_order,
bool call_finish_visit) {
VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")";
- TF_RETURN_IF_ERROR(AcceptInternal(visitor, &operand_order,
- /*ignore_control_predecessors=*/false));
+ TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order,
+ /*ignore_control_predecessors=*/false));
if (call_finish_visit) {
TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index f98bafe81e..cb19c84814 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -825,12 +825,6 @@ class HloInstruction {
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
- // Inner DFS traversal function -- this function being called (rather than
- // Accept above) allows us to distinguish the root of the traversal.
- Status AcceptInternal(DfsHloVisitor* visitor,
- const CompareFunction* operand_order,
- bool ignore_control_predecessors);
-
// CHECKs various invariants of a fusion instruction.
void CheckFusionInstruction() const;
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 57523b34bb..5fa515b26f 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1357,6 +1357,16 @@ xla_test(
],
)
+xla_test(
+ name = "deep_graph_test",
+ srcs = ["deep_graph_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ ],
+)
+
cc_test(
name = "literal_test_util_test",
srcs = ["literal_test_util_test.cc"],
diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc
new file mode 100644
index 0000000000..7a5601ada3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+
+namespace xla {
+namespace {
+TEST_F(ClientLibraryTestBase, DeepGraph) {
+ // TODO(b/62624812): To trigger the stack overflow this test is
+ // intended to track, we need to set kDepth to 20000.
+ // Unfortunately, setting it that high causes the test to time out.
+ const int kDepth = 200;
+ ComputationBuilder b(client_, TestName());
+ ComputationDataHandle x;
+ ComputationDataHandle y;
+ auto x_data = CreateR0Parameter<int32>(3, 0, "x", &b, &x);
+ auto y_data = CreateR0Parameter<int32>(1, 1, "y", &b, &y);
+ ComputationDataHandle z = x;
+ for (int i = 0; i < kDepth; ++i) {
+ z = b.Add(z, y);
+ }
+ ComputeAndCompareR0<int32>(&b, /*expected=*/kDepth + 3,
+ {x_data.get(), y_data.get()});
+}
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
+ xla::legacy_flags::AppendUserComputationFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}