aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-01 10:58:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 11:02:38 -0800
commit11749434e3eb04eee058a43a931a27bdee4916df (patch)
tree8fff987690f5d476127e0b0b30cb0d15c46b6b9d /tensorflow/core/grappler/utils
parent00791693e4d32bed92fcfadf09da321c9f548bab (diff)
Make TopologicalSort return an error status if the sorting fails.
PiperOrigin-RevId: 177612830
Diffstat (limited to 'tensorflow/core/grappler/utils')
-rw-r--r--tensorflow/core/grappler/utils/BUILD1
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.cc25
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.h3
-rw-r--r--tensorflow/core/grappler/utils/topological_sort_test.cc9
4 files changed, 22 insertions, 16 deletions
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 21243833ac..534f7a063f 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -53,6 +53,7 @@ cc_library(
hdrs = ["topological_sort.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:op_types",
diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc
index 77d4702d21..d87f43a498 100644
--- a/tensorflow/core/grappler/utils/topological_sort.cc
+++ b/tensorflow/core/grappler/utils/topological_sort.cc
@@ -19,13 +19,14 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
// Kahn's algorithm is implemented.
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
-void TopologicalSort(GraphDef* graph) {
+Status TopologicalSort(GraphDef* graph) {
OutputMap output_map(graph);
std::vector<NodeDef*> ready_nodes;
ready_nodes.reserve(graph->node_size());
@@ -63,17 +64,19 @@ void TopologicalSort(GraphDef* graph) {
front++;
}
- if (back == graph->node_size()) {
- GraphDef new_graph;
- new_graph.mutable_node()->Reserve(graph->node_size());
- for (int i = 0; i < graph->node_size(); i++) {
- auto new_node = new_graph.add_node();
- new_node->Swap(ready_nodes[i]);
- }
- graph->mutable_node()->Swap(new_graph.mutable_node());
- } else {
- LOG(ERROR) << "The graph couldn't be sorted in topological order.";
+ if (back != graph->node_size()) {
+ return errors::InvalidArgument(
+ "The graph couldn't be sorted in topological order.");
+ }
+
+ GraphDef new_graph;
+ new_graph.mutable_node()->Reserve(graph->node_size());
+ for (int i = 0; i < graph->node_size(); i++) {
+ auto new_node = new_graph.add_node();
+ new_node->Swap(ready_nodes[i]);
}
+ graph->mutable_node()->Swap(new_graph.mutable_node());
+ return Status::OK();
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h
index d4d8034ef5..f2c9bbfa4e 100644
--- a/tensorflow/core/grappler/utils/topological_sort.h
+++ b/tensorflow/core/grappler/utils/topological_sort.h
@@ -17,12 +17,13 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
// Sort a graph in topological order.
-void TopologicalSort(GraphDef* graph);
+Status TopologicalSort(GraphDef* graph);
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc
index dc99cb1052..ba0fe0155a 100644
--- a/tensorflow/core/grappler/utils/topological_sort_test.cc
+++ b/tensorflow/core/grappler/utils/topological_sort_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -51,7 +52,7 @@ TEST_F(TopologicalSortTest, NoLoop) {
*graph.add_node() = CreateNode("5", {});
*graph.add_node() = CreateNode("4", {});
- TopologicalSort(&graph);
+ TF_EXPECT_OK(TopologicalSort(&graph));
std::vector<string> order = {"5", "4", "2", "0", "3", "1"};
for (int i = 0; i < order.size(); i++) {
EXPECT_EQ(graph.node(i).name(), order[i]);
@@ -67,7 +68,7 @@ TEST_F(TopologicalSortTest, WithLoop) {
*graph.add_node() = CreateNode("5", "NextIteration", {"4"});
*graph.add_node() = CreateNode("1", {});
- TopologicalSort(&graph);
+ TF_EXPECT_OK(TopologicalSort(&graph));
std::vector<string> order = {"1", "2", "3", "4", "5"};
for (int i = 0; i < order.size(); i++) {
EXPECT_EQ(graph.node(i).name(), order[i]);
@@ -82,7 +83,7 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) {
*graph.add_node() = CreateNode("3", {"2"});
*graph.add_node() = CreateNode("1", {});
- TopologicalSort(&graph);
+ EXPECT_FALSE(TopologicalSort(&graph).ok());
std::vector<string> order = {"2", "3", "1"};
for (int i = 0; i < order.size(); i++) {
EXPECT_EQ(graph.node(i).name(), order[i]);
@@ -94,7 +95,7 @@ TEST_F(TopologicalSortTest, DuplicatedInputs) {
*graph.add_node() = CreateNode("2", {"1", "1"});
*graph.add_node() = CreateNode("1", {});
- TopologicalSort(&graph);
+ TF_EXPECT_OK(TopologicalSort(&graph));
std::vector<string> order = {"1", "2"};
for (int i = 0; i < order.size(); i++) {
EXPECT_EQ(graph.node(i).name(), order[i]);