aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/mutable_graph_view.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/mutable_graph_view.h')
-rw-r--r--tensorflow/core/grappler/mutable_graph_view.h56
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h
new file mode 100644
index 0000000000..105eb972e8
--- /dev/null
+++ b/tensorflow/core/grappler/mutable_graph_view.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
+#define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
+
+#include "tensorflow/core/grappler/graph_view.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// A utility class to simplify the traversal of a GraphDef that, unlike
+// GraphView, supports updating the graph. Note that you should not modify the
+// graph separately, because the view will get out of sync.
+class MutableGraphView : public GraphView {
+ public:
+ using GraphView::GraphView;
+
+ GraphDef* GetGraph() { return MutableGraph(); }
+ // Adds a new node to graph and updates the view.
+ NodeDef* AddNode(NodeDef&& node);
+
+ // Replaces the input for the output nodes of 'old_input' with a port
+ // `output_port_id` with 'new_input'.
+ //
+ // E.g: We have 2 nodes that use 'bar' node outputs as inputs:
+ // foo(bar:0, bar:1), foo2(other:0, bar:0)
+ // Calling ReplaceInput(bar, new, 0) changes every occurrence of bar:0 for
+ // new:0. Result:
+ // foo(new:0, bar:1), foo2(other:0, new:0)
+ void ReplaceInput(const NodeDef& old_input, const NodeDef& new_input,
+ int output_port_id = 0);
+
+ // Deletes nodes from the graph.
+ void DeleteNodes(const std::set<string>& nodes_to_delete);
+
+ private:
+ void RemoveFanouts(NodeDef* node);
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_