diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/graph_utils.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/graph_utils.cc | 173 |
1 files changed, 100 insertions, 73 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index aece142f7a..6ce6533369 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -16,11 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/grappler/clusters/virtual_cluster.h" -#include "tensorflow/core/grappler/graph_view.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/grappler_item_builder.h" -#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -30,14 +26,12 @@ namespace { constexpr char kConstOpName[] = "Const"; -int FindNodeWithPredicate(const std::function<bool(const NodeDef&)>& predicate, - const GraphDef& graph) { - for (int i = 0; i < graph.node_size(); ++i) { - if (predicate(graph.node(i))) { - return i; - } - } - return -1; +template <typename Predicate, typename Collection> +int GetElementIdxWithPredicate(const Predicate& predicate, + const Collection& collection) { + auto it = std::find_if(collection.begin(), collection.end(), predicate); + if (it == collection.end()) return -1; + return std::distance(collection.begin(), it); } std::vector<int> CreateNameIndex(const GraphDef& graph) { @@ -66,13 +60,14 @@ std::vector<int> CreateInputIndex(const NodeDef& node) { return index; } -Status AddScalarConstNodeHelper( +NodeDef* AddScalarConstNodeHelper( DataType dtype, const std::function<void(TensorProto*)>& add_value, - GraphDef* graph, NodeDef** result) { - NodeDef* node = graph->add_node(); - node->set_op(kConstOpName); - SetUniqueName(kConstOpName, graph, node); - (*node->mutable_attr())["dtype"].set_type(dtype); + MutableGraphView* graph) { + NodeDef node; + node.set_op(kConstOpName); + SetUniqueGraphNodeName(kConstOpName, graph->GetGraph(), &node); + + (*node.mutable_attr())["dtype"].set_type(dtype); std::unique_ptr<tensorflow::TensorProto> tensor = tensorflow::MakeUnique<tensorflow::TensorProto>(); std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape = @@ -80,75 +75,69 @@ Status AddScalarConstNodeHelper( tensor->set_allocated_tensor_shape(tensor_shape.release()); tensor->set_dtype(dtype); add_value(tensor.get()); - (*node->mutable_attr())["value"].set_allocated_tensor(tensor.release()); - *result = node; - return Status::OK(); + (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release()); + + return graph->AddNode(std::move(node)); } } // namespace -Status AddNode(const string& name, const string& op, - const std::vector<string>& inputs, - const std::vector<std::pair<string, AttrValue>>& attributes, - GraphDef* graph, NodeDef** result) { - NodeDef* node = graph->add_node(); +NodeDef* AddNode(const string& name, const string& op, + const std::vector<string>& inputs, + const std::vector<std::pair<string, AttrValue>>& attributes, + MutableGraphView* graph) { + NodeDef node; if (!name.empty()) { - node->set_name(name); + node.set_name(name); } else { - SetUniqueName(op, graph, node); + SetUniqueGraphNodeName(op, graph->GetGraph(), &node); } - node->set_op(op); + node.set_op(op); for (const string& input : inputs) { - node->add_input(input); + node.add_input(input); } for (auto attr : attributes) { - (*node->mutable_attr())[attr.first] = attr.second; + (*node.mutable_attr())[attr.first] = attr.second; } - *result = node; - return Status::OK(); + return graph->AddNode(std::move(node)); } template <> -Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result) { +NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) { return AddScalarConstNodeHelper( - DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph, - result); + DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph); } template <> -Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result) { +NodeDef* AddScalarConstNode(double v, MutableGraphView* graph) { return AddScalarConstNodeHelper( - DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph, - result); + DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph); } template <> -Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result) { +NodeDef* AddScalarConstNode(float v, MutableGraphView* graph) { return AddScalarConstNodeHelper( - DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph, - result); + DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph); } template <> -Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result) { +NodeDef* AddScalarConstNode(int v, MutableGraphView* graph) { return AddScalarConstNodeHelper( - DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph, - result); + DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph); } template <> -Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result) { +NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph) { return AddScalarConstNodeHelper( - DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph, - result); + DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph); } template <> -Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result) { +NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) { return AddScalarConstNodeHelper( DT_STRING, [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); }, - graph, result); + graph); } bool Compare(const GraphDef& g1, const GraphDef& g2) { @@ -181,44 +170,82 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) { return true; } -bool ContainsNodeWithName(const string& name, const GraphDef& graph) { - return FindNodeWithName(name, graph) != -1; +bool ContainsGraphNodeWithName(const string& name, const GraphDef& graph) { + return FindGraphNodeWithName(name, graph) != -1; } bool ContainsNodeWithOp(const string& op, const GraphDef& graph) { return FindNodeWithOp(op, graph) != -1; } -Status DeleteNodes(const std::set<string>& nodes_to_delete, GraphDef* graph) { - int last = graph->node_size() - 1; - for (int i = graph->node_size() - 1; i >= 0; --i) { - const NodeDef& node = graph->node(i); - if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) { - graph->mutable_node()->SwapElements(i, last); - last--; - } - } - graph->mutable_node()->DeleteSubrange(last + 1, - graph->node_size() - last - 1); - return Status::OK(); +bool ContainsGraphFunctionWithName(const string& name, + const FunctionDefLibrary& library) { + return FindGraphFunctionWithName(name, library) != -1; +} + +bool ContainsFunctionNodeWithName(const string& name, + const FunctionDef& function) { + return FindFunctionNodeWithName(name, function) != -1; } -int FindNodeWithName(const string& name, const GraphDef& graph) { - return FindNodeWithPredicate( - [name](const NodeDef& node) { return node.name() == name; }, graph); +int FindGraphNodeWithName(const string& name, const GraphDef& graph) { + return GetElementIdxWithPredicate( + [&name](const NodeDef& node) { return node.name() == name; }, + graph.node()); } int FindNodeWithOp(const string& op, const GraphDef& graph) { - return FindNodeWithPredicate( - [op](const NodeDef& node) { return node.op() == op; }, graph); + return GetElementIdxWithPredicate( + [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); +} + +int FindGraphFunctionWithName(const string& name, + const FunctionDefLibrary& library) { + return GetElementIdxWithPredicate( + [&name](const FunctionDef& function) { + return function.signature().name() == name; + }, + library.function()); } -void SetUniqueName(const string& op, GraphDef* graph, NodeDef* node) { +int FindFunctionNodeWithName(const string& name, const FunctionDef& function) { + return GetElementIdxWithPredicate( + [&name](const NodeDef& node) { return node.name() == name; }, + function.node_def()); +} + +void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph, + NodeDef* node) { + string name = prefix; int id = graph->node_size(); - while (ContainsNodeWithName(strings::StrCat(op, "/_", id), *graph)) { + while (ContainsGraphNodeWithName(name, *graph)) { + name = strings::StrCat(prefix, "/_", id); + ++id; + } + node->set_name(std::move(name)); +} + +void SetUniqueFunctionNodeName(const string& prefix, FunctionDef* function, + NodeDef* node) { + string name = prefix; + int id = function->node_def_size(); + while (ContainsFunctionNodeWithName(name, *function)) { + name = strings::StrCat(prefix, "/_", id); + ++id; + } + node->set_name(std::move(name)); +} + +void SetUniqueGraphFunctionName(const string& prefix, + FunctionDefLibrary* library, + FunctionDef* function) { + string name = prefix; + int id = library->function_size(); + while (ContainsGraphFunctionWithName(name, *library)) { + name = strings::StrCat(prefix, "/_", id); ++id; } - node->set_name(strings::StrCat(op, "/_", id)); + function->mutable_signature()->set_name(name); } } // end namespace graph_utils |