aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/graph_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/graph_utils.h')
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h72
1 files changed, 52 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 3d2467031f..0847748802 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -17,11 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -30,53 +32,83 @@ namespace grappler {
namespace graph_utils {
// Adds a node to the graph.
-Status AddNode(const string& name, const string& op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- GraphDef* graph, NodeDef** result);
+NodeDef* AddNode(const string& name, const string& op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ MutableGraphView* graph);
// Adds a Const node with the given value to the graph.
template <typename T>
-Status AddScalarConstNode(T v, GraphDef* graph, NodeDef** result) {
- return errors::Unimplemented("Type %s is not supported.",
- DataTypeToEnum<T>::value);
+NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
+ // is_same is an idiomatic hack for making it compile if not instantiated.
+ // Replacing with false will result in a compile-time error.
+ static_assert(!std::is_same<T, T>::value,
+ "Invalid specialization of this method for type T.");
+ return {};
}
+
template <>
-Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result);
+NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph);
template <>
-Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result);
+NodeDef* AddScalarConstNode(double v, MutableGraphView* graph);
template <>
-Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result);
+NodeDef* AddScalarConstNode(float v, MutableGraphView* graph);
template <>
-Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result);
+NodeDef* AddScalarConstNode(int v, MutableGraphView* graph);
template <>
-Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result);
+NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph);
template <>
-Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result);
+NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph);
// Checks whether the two graphs are the same.
bool Compare(const GraphDef& g1, const GraphDef& g2);
// Checks whether the graph contains a node with the given name.
-bool ContainsNodeWithName(const string& name, const GraphDef& graph);
+bool ContainsGraphNodeWithName(const string& name, const GraphDef& graph);
+
+// Checks whether the library contains a function with the given name.
+bool ContainsGraphFunctionWithName(const string& name,
+ const FunctionDefLibrary& library);
+
+// Checks whether the function contains a node with the given name.
+bool ContainsFunctionNodeWithName(const string& name,
+ const FunctionDef& function);
// Checks whether the graph contains a node with the given op.
bool ContainsNodeWithOp(const string& op, const GraphDef& graph);
-// Deletes nodes from the graph.
-Status DeleteNodes(const std::set<string>& nodes_to_delete, GraphDef* graph);
-
// Returns the index of the node with the given name or -1 if the node does
// not exist.
-int FindNodeWithName(const string& name, const GraphDef& graph);
+int FindGraphNodeWithName(const string& name, const GraphDef& graph);
+
+// Returns the index of the function with the given name or -1 if the function
+// does not exist.
+int FindGraphFunctionWithName(const string& name,
+ const FunctionDefLibrary& library);
+
+// Returns the index of the function node with the given name or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithName(const string& name, const FunctionDef& function);
// Returns the index of a node with the given op or -1 if no such node
// exists.
int FindNodeWithOp(const string& op, const GraphDef& graph);
-// Sets the node name using the op name as a prefix while guaranteeing the name
+// Sets the node name using `prefix` as a prefix while guaranteeing the name
// is unique across the graph.
-void SetUniqueName(const string& op, GraphDef* graph, NodeDef* node);
+void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph,
+ NodeDef* node);
+
+// Sets the function node name using the `prefix` as a prefix while guaranteeing
+// the name is unique across the functions nodes.
+void SetUniqueFunctionNodeName(const string& prefix, FunctionDef* function,
+ NodeDef* node);
+
+// Sets the node name using the `prefix` name as a prefix while guaranteeing the
+// name is unique across the graph.
+void SetUniqueGraphFunctionName(const string& prefix,
+ FunctionDefLibrary* library,
+ FunctionDef* function);
} // end namespace graph_utils
} // end namespace grappler