aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-03 16:12:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-03 16:20:00 -0700
commit1b414c84c138ca67da96e08ded600234fbd2427b (patch)
treeaaac9ea3fee948dc0f3ca4ae7110bdf6cdaa4081
parent8cd7135ea19fc184c873a1ba1463ff4b77f3c7ad (diff)
Adds a function to sort a graph in reverse topological order.
PiperOrigin-RevId: 207340526
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.cc9
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.h3
2 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc
index ff89035902..63ca92c69e 100644
--- a/tensorflow/core/grappler/utils/topological_sort.cc
+++ b/tensorflow/core/grappler/utils/topological_sort.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include <algorithm>
#include <deque>
#include <unordered_map>
#include "tensorflow/core/framework/node_def.pb.h"
@@ -85,6 +86,14 @@ Status ComputeTopologicalOrder(
return Status::OK();
}
+Status ReversedTopologicalSort(GraphDef* graph) {
+ std::vector<int> ready_nodes;
+ TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
+ std::reverse(ready_nodes.begin(), ready_nodes.end());
+ PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
+ return Status::OK();
+}
+
Status TopologicalSort(GraphDef* graph) {
std::vector<int> ready_nodes;
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h
index bc0299a7b8..b8cf897a32 100644
--- a/tensorflow/core/grappler/utils/topological_sort.h
+++ b/tensorflow/core/grappler/utils/topological_sort.h
@@ -31,6 +31,9 @@ Status ComputeTopologicalOrder(
// Sort a graph in topological order.
Status TopologicalSort(GraphDef* graph);
+// Sort a graph in topological order and reverse it.
+Status ReversedTopologicalSort(GraphDef* graph);
+
} // namespace grappler
} // namespace tensorflow