aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/graph_view.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/graph_view.cc')
-rw-r--r--tensorflow/core/grappler/graph_view.cc49
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index 0d3f94854b..3e448216f9 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -173,5 +173,54 @@ int GraphView::NumFanins(const NodeDef& node,
return count;
}
+std::unordered_set<GraphView::Edge, GraphView::HashEdge>
+GraphView::GetFanoutEdges(const NodeDef& node,
+ bool include_controlled_edges) const {
+ std::unordered_set<Edge, HashEdge> result;
+ OutputPort port;
+ port.node = const_cast<NodeDef*>(&node);
+ const int first_port_id = include_controlled_edges ? -1 : 0;
+ auto it = num_regular_outputs_.find(&node);
+ const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
+
+ for (int i = first_port_id; i <= last_port_id; ++i) {
+ port.port_id = i;
+ auto it = fanouts_.find(port);
+ if (it != fanouts_.end()) {
+ Edge fanout;
+ fanout.src.node = const_cast<NodeDef*>(&node);
+ fanout.src.port_id = i;
+ for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
+ fanout.tgt = *itr;
+ result.insert(fanout);
+ }
+ }
+ }
+ return result;
+}
+
+std::unordered_set<GraphView::Edge, GraphView::HashEdge>
+GraphView::GetFaninEdges(const NodeDef& node,
+ bool include_controlling_edges) const {
+ std::unordered_set<Edge, HashEdge> result;
+ for (int i = 0; i < node.input_size(); ++i) {
+ Edge fanin;
+ fanin.tgt.node = const_cast<NodeDef*>(&node);
+ fanin.tgt.port_id = i;
+ string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id);
+ if (fanin.src.port_id < 0) {
+ if (!include_controlling_edges) {
+ break;
+ }
+ }
+ auto it = nodes_.find(fanin_name);
+ if (it != nodes_.end()) {
+ fanin.src.node = it->second;
+ result.insert(fanin);
+ }
+ }
+ return result;
+}
+
} // end namespace grappler
} // end namespace tensorflow