diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/graph_utils.h')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/graph_utils.h | 72 |
1 files changed, 52 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 3d2467031f..0847748802 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -17,11 +17,13 @@ limitations under the License. #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_ #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" @@ -30,53 +32,83 @@ namespace grappler { namespace graph_utils { // Adds a node to the graph. -Status AddNode(const string& name, const string& op, - const std::vector<string>& inputs, - const std::vector<std::pair<string, AttrValue>>& attributes, - GraphDef* graph, NodeDef** result); +NodeDef* AddNode(const string& name, const string& op, + const std::vector<string>& inputs, + const std::vector<std::pair<string, AttrValue>>& attributes, + MutableGraphView* graph); // Adds a Const node with the given value to the graph. template <typename T> -Status AddScalarConstNode(T v, GraphDef* graph, NodeDef** result) { - return errors::Unimplemented("Type %s is not supported.", - DataTypeToEnum<T>::value); +NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) { + // is_same is an idiomatic hack for making it compile if not instantiated. + // Replacing with false will result in a compile-time error. + static_assert(!std::is_same<T, T>::value, + "Invalid specialization of this method for type T."); + return {}; } + template <> -Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result); +NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph); template <> -Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result); +NodeDef* AddScalarConstNode(double v, MutableGraphView* graph); template <> -Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result); +NodeDef* AddScalarConstNode(float v, MutableGraphView* graph); template <> -Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result); +NodeDef* AddScalarConstNode(int v, MutableGraphView* graph); template <> -Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result); +NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph); template <> -Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result); +NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph); // Checks whether the two graphs are the same. bool Compare(const GraphDef& g1, const GraphDef& g2); // Checks whether the graph contains a node with the given name. -bool ContainsNodeWithName(const string& name, const GraphDef& graph); +bool ContainsGraphNodeWithName(const string& name, const GraphDef& graph); + +// Checks whether the library contains a function with the given name. +bool ContainsGraphFunctionWithName(const string& name, + const FunctionDefLibrary& library); + +// Checks whether the function contains a node with the given name. +bool ContainsFunctionNodeWithName(const string& name, + const FunctionDef& function); // Checks whether the graph contains a node with the given op. bool ContainsNodeWithOp(const string& op, const GraphDef& graph); -// Deletes nodes from the graph. -Status DeleteNodes(const std::set<string>& nodes_to_delete, GraphDef* graph); - // Returns the index of the node with the given name or -1 if the node does // not exist. -int FindNodeWithName(const string& name, const GraphDef& graph); +int FindGraphNodeWithName(const string& name, const GraphDef& graph); + +// Returns the index of the function with the given name or -1 if the function +// does not exist. +int FindGraphFunctionWithName(const string& name, + const FunctionDefLibrary& library); + +// Returns the index of the function node with the given name or -1 if the +// function node does not exist. +int FindFunctionNodeWithName(const string& name, const FunctionDef& function); // Returns the index of a node with the given op or -1 if no such node // exists. int FindNodeWithOp(const string& op, const GraphDef& graph); -// Sets the node name using the op name as a prefix while guaranteeing the name +// Sets the node name using `prefix` as a prefix while guaranteeing the name // is unique across the graph. -void SetUniqueName(const string& op, GraphDef* graph, NodeDef* node); +void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph, + NodeDef* node); + +// Sets the function node name using the `prefix` as a prefix while guaranteeing +// the name is unique across the functions nodes. +void SetUniqueFunctionNodeName(const string& prefix, FunctionDef* function, + NodeDef* node); + +// Sets the node name using the `prefix` name as a prefix while guaranteeing the +// name is unique across the graph. +void SetUniqueGraphFunctionName(const string& prefix, + FunctionDefLibrary* library, + FunctionDef* function); } // end namespace graph_utils } // end namespace grappler |