aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/mutable_graph_view.h
blob: 971e5503d4ce908dbb86a4f127ac4da6bea95874 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
#define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_

#include "tensorflow/core/grappler/graph_view.h"

namespace tensorflow {
namespace grappler {

// A utility class to simplify the traversal of a GraphDef that, unlike
// GraphView, supports updating the graph.  Note that you should not modify the
// graph separately, because the view will get out of sync.
// MutableGraphView extends GraphView with operations that edit the wrapped
// GraphDef. Each editing method below also updates the view so that later
// traversals see the change.
//
// Route every modification through this class: if the GraphDef is edited
// behind the view's back, the view goes stale (out of sync with the graph).
class MutableGraphView : public GraphView {
 public:
  // Construction is identical to GraphView; its constructors are inherited.
  using GraphView::GraphView;

  // Returns the (mutable) GraphDef that this view wraps, as provided by
  // GraphView::MutableGraph().
  GraphDef* GetGraph() {
    return MutableGraph();
  }

  // Moves `node` into the graph and updates the view to include it.
  // NOTE(review): the returned pointer presumably refers to the copy stored
  // in the graph — confirm against the implementation.
  NodeDef* AddNode(NodeDef&& node);

  // Adds `node` to the graph directly after the `input` node and updates the
  // view. Every consumer that was reading output port `output_port_id` of
  // `input` is rewired to read from `node` instead.
  NodeDef* InsertNode(const NodeDef& input, NodeDef&& node,
                      int output_port_id = 0);

  // Rewires every consumer of output port `output_port_id` of 'old_input' so
  // that it reads from 'new_input' instead.
  //
  // Example, with two nodes that consume outputs of 'bar':
  //   foo(bar:0, bar:1),  foo2(other:0, bar:0)
  // ReplaceInput(bar, new, 0) replaces every occurrence of bar:0 with new:0,
  // leaving other inputs alone:
  //   foo(new:0, bar:1),  foo2(other:0, new:0)
  void ReplaceInput(const NodeDef& old_input, const NodeDef& new_input,
                    int output_port_id = 0);

  // Removes the nodes identified by `nodes_to_delete` from the graph.
  // NOTE(review): the strings are presumably node names — confirm in the .cc.
  void DeleteNodes(const std::set<string>& nodes_to_delete);

 private:
  // Internal helper. Judging by its name it removes the view's fanout
  // bookkeeping for `node`; NOTE(review): confirm against the implementation.
  void RemoveFanouts(NodeDef* node);
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_