diff options
Diffstat (limited to 'tensorflow/core/grappler/graph_view.cc')
-rw-r--r-- | tensorflow/core/grappler/graph_view.cc | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index 0d3f94854b..3e448216f9 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -173,5 +173,54 @@ int GraphView::NumFanins(const NodeDef& node, return count; } +std::unordered_set<GraphView::Edge, GraphView::HashEdge> +GraphView::GetFanoutEdges(const NodeDef& node, + bool include_controlled_edges) const { + std::unordered_set<Edge, HashEdge> result; + OutputPort port; + port.node = const_cast<NodeDef*>(&node); + const int first_port_id = include_controlled_edges ? -1 : 0; + auto it = num_regular_outputs_.find(&node); + const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1; + + for (int i = first_port_id; i <= last_port_id; ++i) { + port.port_id = i; + auto it = fanouts_.find(port); + if (it != fanouts_.end()) { + Edge fanout; + fanout.src.node = const_cast<NodeDef*>(&node); + fanout.src.port_id = i; + for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) { + fanout.tgt = *itr; + result.insert(fanout); + } + } + } + return result; +} + +std::unordered_set<GraphView::Edge, GraphView::HashEdge> +GraphView::GetFaninEdges(const NodeDef& node, + bool include_controlling_edges) const { + std::unordered_set<Edge, HashEdge> result; + for (int i = 0; i < node.input_size(); ++i) { + Edge fanin; + fanin.tgt.node = const_cast<NodeDef*>(&node); + fanin.tgt.port_id = i; + string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id); + if (fanin.src.port_id < 0) { + if (!include_controlling_edges) { + break; + } + } + auto it = nodes_.find(fanin_name); + if (it != nodes_.end()) { + fanin.src.node = it->second; + result.insert(fanin); + } + } + return result; +} + } // end namespace grappler } // end namespace tensorflow |