From 11749434e3eb04eee058a43a931a27bdee4916df Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2017 10:58:17 -0800 Subject: Make TopologicalSort return an error status if the sorting fails. PiperOrigin-RevId: 177612830 --- tensorflow/core/grappler/utils/BUILD | 1 + tensorflow/core/grappler/utils/topological_sort.cc | 25 ++++++++++++---------- tensorflow/core/grappler/utils/topological_sort.h | 3 ++- .../core/grappler/utils/topological_sort_test.cc | 9 ++++---- 4 files changed, 22 insertions(+), 16 deletions(-) (limited to 'tensorflow/core/grappler/utils') diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 21243833ac..534f7a063f 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -53,6 +53,7 @@ cc_library( hdrs = ["topological_sort.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:op_types", diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index 77d4702d21..d87f43a498 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -19,13 +19,14 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace grappler { // Kahn's algorithm is implemented. // For details, see https://en.wikipedia.org/wiki/Topological_sorting -void TopologicalSort(GraphDef* graph) { +Status TopologicalSort(GraphDef* graph) { OutputMap output_map(graph); std::vector ready_nodes; ready_nodes.reserve(graph->node_size()); @@ -63,17 +64,19 @@ void TopologicalSort(GraphDef* graph) { front++; } - if (back == graph->node_size()) { - GraphDef new_graph; - new_graph.mutable_node()->Reserve(graph->node_size()); - for (int i = 0; i < graph->node_size(); i++) { - auto new_node = new_graph.add_node(); - new_node->Swap(ready_nodes[i]); - } - graph->mutable_node()->Swap(new_graph.mutable_node()); - } else { - LOG(ERROR) << "The graph couldn't be sorted in topological order."; + if (back != graph->node_size()) { + return errors::InvalidArgument( + "The graph couldn't be sorted in topological order."); + } + + GraphDef new_graph; + new_graph.mutable_node()->Reserve(graph->node_size()); + for (int i = 0; i < graph->node_size(); i++) { + auto new_node = new_graph.add_node(); + new_node->Swap(ready_nodes[i]); } + graph->mutable_node()->Swap(new_graph.mutable_node()); + return Status::OK(); } } // namespace grappler diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h index d4d8034ef5..f2c9bbfa4e 100644 --- a/tensorflow/core/grappler/utils/topological_sort.h +++ b/tensorflow/core/grappler/utils/topological_sort.h @@ -17,12 +17,13 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace grappler { // Sort a graph in topological order. -void TopologicalSort(GraphDef* graph); +Status TopologicalSort(GraphDef* graph); } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc index dc99cb1052..ba0fe0155a 100644 --- a/tensorflow/core/grappler/utils/topological_sort_test.cc +++ b/tensorflow/core/grappler/utils/topological_sort_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -51,7 +52,7 @@ TEST_F(TopologicalSortTest, NoLoop) { *graph.add_node() = CreateNode("5", {}); *graph.add_node() = CreateNode("4", {}); - TopologicalSort(&graph); + TF_EXPECT_OK(TopologicalSort(&graph)); std::vector order = {"5", "4", "2", "0", "3", "1"}; for (int i = 0; i < order.size(); i++) { EXPECT_EQ(graph.node(i).name(), order[i]); @@ -67,7 +68,7 @@ TEST_F(TopologicalSortTest, WithLoop) { *graph.add_node() = CreateNode("5", "NextIteration", {"4"}); *graph.add_node() = CreateNode("1", {}); - TopologicalSort(&graph); + TF_EXPECT_OK(TopologicalSort(&graph)); std::vector order = {"1", "2", "3", "4", "5"}; for (int i = 0; i < order.size(); i++) { EXPECT_EQ(graph.node(i).name(), order[i]); @@ -82,7 +83,7 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) { *graph.add_node() = CreateNode("3", {"2"}); *graph.add_node() = CreateNode("1", {}); - TopologicalSort(&graph); + EXPECT_FALSE(TopologicalSort(&graph).ok()); std::vector order = {"2", "3", "1"}; for (int i = 0; i < order.size(); i++) { EXPECT_EQ(graph.node(i).name(), order[i]); @@ -94,7 +95,7 @@ TEST_F(TopologicalSortTest, DuplicatedInputs) { *graph.add_node() = CreateNode("2", {"1", "1"}); *graph.add_node() = CreateNode("1", {}); - TopologicalSort(&graph); + TF_EXPECT_OK(TopologicalSort(&graph)); std::vector order = {"1", "2"}; for (int i = 0; i < order.size(); i++) { EXPECT_EQ(graph.node(i).name(), order[i]); -- cgit v1.2.3