aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/graph_view.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/graph_view.h')
-rw-r--r--tensorflow/core/grappler/graph_view.h20
1 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index 584cb9048b..050789d2e2 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -29,8 +29,11 @@ namespace grappler {
class GraphView {
public:
struct Port {
- Port() : node(nullptr), port_id(-1) {}
+ Port() = default;
Port(NodeDef* n, int port) : node(n), port_id(port) {}
+
+ // TODO(prazek): ports should keep the constness of GraphView. The only way
+ // to modify graph through the view should be using MutableGraphView.
NodeDef* node = nullptr;
int port_id = -1;
@@ -111,13 +114,22 @@ class GraphView {
std::unordered_set<Edge, HashEdge> GetFaninEdges(
const NodeDef& node, bool include_controlling_edges) const;
+ protected:
+ // Add fanout to every `node` input.
+ void AddFanouts(NodeDef* node);
+ std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; }
+ GraphDef* MutableGraph() { return graph_; }
+
+ using FanoutsMapType =
+ std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>,
+ HashPort>;
+ FanoutsMapType* MutableFanouts() { return &fanouts_; }
+
private:
GraphDef* graph_;
std::unordered_map<string, NodeDef*> nodes_;
std::unordered_set<InputPort, HashPort> empty_set_;
- std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>,
- HashPort>
- fanouts_;
+ FanoutsMapType fanouts_;
std::unordered_map<const NodeDef*, int> num_regular_outputs_;
};