diff options
Diffstat (limited to 'tensorflow/core/grappler/graph_view.h')
-rw-r--r-- | tensorflow/core/grappler/graph_view.h | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index 584cb9048b..050789d2e2 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -29,8 +29,11 @@ namespace grappler { class GraphView { public: struct Port { - Port() : node(nullptr), port_id(-1) {} + Port() = default; Port(NodeDef* n, int port) : node(n), port_id(port) {} + + // TODO(prazek): ports should keep the constness of GraphView. The only way + // to modify graph through the view should be using MutableGraphView. NodeDef* node = nullptr; int port_id = -1; @@ -111,13 +114,22 @@ class GraphView { std::unordered_set<Edge, HashEdge> GetFaninEdges( const NodeDef& node, bool include_controlling_edges) const; + protected: + // Add fanout to every `node` input. + void AddFanouts(NodeDef* node); + std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; } + GraphDef* MutableGraph() { return graph_; } + + using FanoutsMapType = + std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>, + HashPort>; + FanoutsMapType* MutableFanouts() { return &fanouts_; } + private: GraphDef* graph_; std::unordered_map<string, NodeDef*> nodes_; std::unordered_set<InputPort, HashPort> empty_set_; - std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>, - HashPort> - fanouts_; + FanoutsMapType fanouts_; std::unordered_map<const NodeDef*, int> num_regular_outputs_; }; |