aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/graph_utils.h
blob: d130fee2047e5be49857dea6ac6489f93088aa50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {
namespace graph_utils {

// Returns the zero-based index of the first element in `collection` for which
// `predicate(element)` is true. If no such element exists (including when
// `collection` is empty), returns -1.
//
// `collection` is traversed once, front to back, and `predicate` is not
// invoked again after the first match.
template <typename Predicate, typename Collection>
int GetFirstElementIndexWithPredicate(const Predicate& predicate,
                                      const Collection& collection) {
  // The counter uses the (signed) return type so that returning it does not
  // involve an implicit unsigned-to-int conversion.
  int idx = 0;
  for (auto&& element : collection) {
    if (predicate(element)) {
      return idx;
    }
    ++idx;
  }
  return -1;
}

// Adds a node to the graph.
//
// NOTE(review): only the declaration is visible in this header. Going by the
// parameter names, `name` and `op` are the new node's name and op type,
// `inputs` lists its inputs, and `attributes` holds (attribute name, value)
// pairs to set on it; the returned pointer presumably refers to the node that
// was added to `graph`. Confirm against the definition, in particular how an
// empty or already-used `name` is handled.
NodeDef* AddNode(StringPiece name, StringPiece op,
                 const std::vector<string>& inputs,
                 const std::vector<std::pair<string, AttrValue>>& attributes,
                 MutableGraphView* graph);

// Adds a Placeholder node for the given type `dtype` to `graph`.
// NOTE(review): the name given to the node and the exact meaning of the
// returned pointer are decided by the definition, which is not in this header;
// presumably the pointer refers to the newly added node.
NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);

// Adds a Const node with the given value to the graph.
//
// This primary template is a catch-all that deliberately fails to compile when
// it is instantiated; callers are expected to hit one of the explicit
// specializations of this function instead.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
  // The flag below is always true, but because it depends on T the assertion
  // is only evaluated when the template is actually instantiated. Asserting on
  // a plain `false` would break the build even if nobody instantiated it.
  constexpr bool kDependentTrue = std::is_same<T, T>::value;
  static_assert(!kDependentTrue,
                "Invalid specialization of this method for type T.");
  return nullptr;
}

// Explicit specializations of AddScalarConstNode for the supported scalar
// value types: bool, double, float, int, int64 and StringPiece. They are only
// declared here; their definitions are not part of this header.
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(double v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(float v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(int v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph);
template <>
NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph);

// Checks whether the two graphs are the same.
// Returns true if `g1` and `g2` are considered the same, false otherwise.
// NOTE(review): what counts as "the same" (for example, whether node order
// matters) is decided by the definition, which is not visible in this header.
bool Compare(const GraphDef& g1, const GraphDef& g2);

// Checks whether the graph contains a node with the given name.
// Returns true if `graph` has a node named `name`, false otherwise.
bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph);

// Checks whether the library contains a function with the given name.
// Returns true if `library` has a function named `name`, false otherwise.
bool ContainsGraphFunctionWithName(StringPiece name,
                                   const FunctionDefLibrary& library);

// Checks whether the graph contains a node with the given op.
// Returns true if `graph` has at least one node whose op is `op`, false
// otherwise.
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);

// Returns the index of the node with the given name or -1 if the node does
// not exist.
// NOTE(review): the index presumably refers to a position in `graph`'s node
// list; confirm against the definition.
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph);

// Returns the index of the function with the given name or -1 if the function
// does not exist.
// NOTE(review): the index presumably refers to a position in `library`'s
// function list; confirm against the definition.
int FindGraphFunctionWithName(StringPiece name,
                              const FunctionDefLibrary& library);

// Returns the index of the first node with the given op or -1 if no such node
// exists.
// NOTE(review): the index presumably refers to a position in `graph`'s node
// list; confirm against the definition.
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);

// Gets the 0th input to a node in the graph.
// NOTE(review): the result for a `node` that has no inputs is not specified
// here; check the definition before relying on a non-null return value.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);

// Returns the list of indices of all nodes with the given op or empty list if
// no such node exists.
// Unlike the other lookup helpers declared in this header, `op` is taken as a
// `const string&` rather than a StringPiece.
// NOTE(review): the ordering of the returned indices is not stated here;
// confirm against the definition if callers depend on it.
std::vector<int> FindAllGraphNodesWithOp(const string& op,
                                         const GraphDef& graph);

// Sets the name of `node` using `prefix` as a prefix while guaranteeing the
// name is unique across `graph`.
// NOTE(review): how the unique suffix is chosen, and whether `node` must
// already belong to `graph`, is decided by the definition (not in this header).
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);

// Sets the name of `function` using `prefix` as a prefix while guaranteeing
// the name is unique across the function library `library`.
// NOTE(review): how the unique suffix is chosen, and whether `function` must
// already belong to `library`, is decided by the definition (not in this
// header).
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
                                FunctionDef* function);

// Copies the attribute named `attribute_name` from node `from` to node
// `to_node`.
// NOTE(review): the behavior when `from` has no attribute with that name is
// not specified here; confirm against the definition.
void CopyAttribute(const string& attribute_name, const NodeDef& from,
                   NodeDef* to_node);

// Concatenates the list attribute named `attribute_name` of the `first` and
// `second` nodes and sets the result on `to_node`.
// NOTE(review): presumably the elements of `first` precede those of `second`
// in the result; the behavior when either node lacks the attribute is not
// specified here. Confirm both against the definition.
void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
                         const NodeDef& second, NodeDef* to_node);

// Checks that all nodes in the graph `g` have unique names, and sets their
// names to be unique if they are not already.  This is necessary as Graph does
// not have the provisions to deduplicate names, and name deduplication
// elsewhere in tensorflow happens in other layers (for example, in the Scope
// class of the C++ API). Note that the nodes in the graph are identified by
// their id, and renaming nodes does not mutate any edges.
// NOTE(review): the conditions under which a non-OK Status is returned are not
// visible in this header; check the definition.
Status EnsureNodeNamesUnique(Graph* g);

}  // end namespace graph_utils
}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_