aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/graph_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/graph_utils.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc173
1 files changed, 100 insertions, 73 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index aece142f7a..6ce6533369 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -16,11 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/device_base.h"
-#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
-#include "tensorflow/core/grappler/graph_view.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/grappler_item_builder.h"
-#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -30,14 +26,12 @@ namespace {
constexpr char kConstOpName[] = "Const";
-int FindNodeWithPredicate(const std::function<bool(const NodeDef&)>& predicate,
- const GraphDef& graph) {
- for (int i = 0; i < graph.node_size(); ++i) {
- if (predicate(graph.node(i))) {
- return i;
- }
- }
- return -1;
+template <typename Predicate, typename Collection>
+int GetElementIdxWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ auto it = std::find_if(collection.begin(), collection.end(), predicate);
+ if (it == collection.end()) return -1;
+ return std::distance(collection.begin(), it);
}
std::vector<int> CreateNameIndex(const GraphDef& graph) {
@@ -66,13 +60,14 @@ std::vector<int> CreateInputIndex(const NodeDef& node) {
return index;
}
-Status AddScalarConstNodeHelper(
+NodeDef* AddScalarConstNodeHelper(
DataType dtype, const std::function<void(TensorProto*)>& add_value,
- GraphDef* graph, NodeDef** result) {
- NodeDef* node = graph->add_node();
- node->set_op(kConstOpName);
- SetUniqueName(kConstOpName, graph, node);
- (*node->mutable_attr())["dtype"].set_type(dtype);
+ MutableGraphView* graph) {
+ NodeDef node;
+ node.set_op(kConstOpName);
+ SetUniqueGraphNodeName(kConstOpName, graph->GetGraph(), &node);
+
+ (*node.mutable_attr())["dtype"].set_type(dtype);
std::unique_ptr<tensorflow::TensorProto> tensor =
tensorflow::MakeUnique<tensorflow::TensorProto>();
std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
@@ -80,75 +75,69 @@ Status AddScalarConstNodeHelper(
tensor->set_allocated_tensor_shape(tensor_shape.release());
tensor->set_dtype(dtype);
add_value(tensor.get());
- (*node->mutable_attr())["value"].set_allocated_tensor(tensor.release());
- *result = node;
- return Status::OK();
+ (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
+
+ return graph->AddNode(std::move(node));
}
} // namespace
-Status AddNode(const string& name, const string& op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- GraphDef* graph, NodeDef** result) {
- NodeDef* node = graph->add_node();
+NodeDef* AddNode(const string& name, const string& op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ MutableGraphView* graph) {
+ NodeDef node;
if (!name.empty()) {
- node->set_name(name);
+ node.set_name(name);
} else {
- SetUniqueName(op, graph, node);
+ SetUniqueGraphNodeName(op, graph->GetGraph(), &node);
}
- node->set_op(op);
+ node.set_op(op);
for (const string& input : inputs) {
- node->add_input(input);
+ node.add_input(input);
}
for (auto attr : attributes) {
- (*node->mutable_attr())[attr.first] = attr.second;
+ (*node.mutable_attr())[attr.first] = attr.second;
}
- *result = node;
- return Status::OK();
+ return graph->AddNode(std::move(node));
}
template <>
-Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result) {
+NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
- DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph,
- result);
+ DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph);
}
template <>
-Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result) {
+NodeDef* AddScalarConstNode(double v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
- DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph,
- result);
+ DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph);
}
template <>
-Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result) {
+NodeDef* AddScalarConstNode(float v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
- DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph,
- result);
+ DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph);
}
template <>
-Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result) {
+NodeDef* AddScalarConstNode(int v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
- DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph,
- result);
+ DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph);
}
template <>
-Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result) {
+NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
- DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph,
- result);
+ DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph);
}
template <>
-Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result) {
+NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
DT_STRING,
[v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
- graph, result);
+ graph);
}
bool Compare(const GraphDef& g1, const GraphDef& g2) {
@@ -181,44 +170,82 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) {
return true;
}
-bool ContainsNodeWithName(const string& name, const GraphDef& graph) {
- return FindNodeWithName(name, graph) != -1;
+bool ContainsGraphNodeWithName(const string& name, const GraphDef& graph) {
+ return FindGraphNodeWithName(name, graph) != -1;
}
bool ContainsNodeWithOp(const string& op, const GraphDef& graph) {
return FindNodeWithOp(op, graph) != -1;
}
-Status DeleteNodes(const std::set<string>& nodes_to_delete, GraphDef* graph) {
- int last = graph->node_size() - 1;
- for (int i = graph->node_size() - 1; i >= 0; --i) {
- const NodeDef& node = graph->node(i);
- if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) {
- graph->mutable_node()->SwapElements(i, last);
- last--;
- }
- }
- graph->mutable_node()->DeleteSubrange(last + 1,
- graph->node_size() - last - 1);
- return Status::OK();
+bool ContainsGraphFunctionWithName(const string& name,
+ const FunctionDefLibrary& library) {
+ return FindGraphFunctionWithName(name, library) != -1;
+}
+
+bool ContainsFunctionNodeWithName(const string& name,
+ const FunctionDef& function) {
+ return FindFunctionNodeWithName(name, function) != -1;
}
-int FindNodeWithName(const string& name, const GraphDef& graph) {
- return FindNodeWithPredicate(
- [name](const NodeDef& node) { return node.name() == name; }, graph);
+int FindGraphNodeWithName(const string& name, const GraphDef& graph) {
+ return GetElementIdxWithPredicate(
+ [&name](const NodeDef& node) { return node.name() == name; },
+ graph.node());
}
int FindNodeWithOp(const string& op, const GraphDef& graph) {
- return FindNodeWithPredicate(
- [op](const NodeDef& node) { return node.op() == op; }, graph);
+ return GetElementIdxWithPredicate(
+ [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
+}
+
+int FindGraphFunctionWithName(const string& name,
+ const FunctionDefLibrary& library) {
+ return GetElementIdxWithPredicate(
+ [&name](const FunctionDef& function) {
+ return function.signature().name() == name;
+ },
+ library.function());
}
-void SetUniqueName(const string& op, GraphDef* graph, NodeDef* node) {
+int FindFunctionNodeWithName(const string& name, const FunctionDef& function) {
+ return GetElementIdxWithPredicate(
+ [&name](const NodeDef& node) { return node.name() == name; },
+ function.node_def());
+}
+
+void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph,
+ NodeDef* node) {
+ string name = prefix;
int id = graph->node_size();
- while (ContainsNodeWithName(strings::StrCat(op, "/_", id), *graph)) {
+ while (ContainsGraphNodeWithName(name, *graph)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ node->set_name(std::move(name));
+}
+
+void SetUniqueFunctionNodeName(const string& prefix, FunctionDef* function,
+ NodeDef* node) {
+ string name = prefix;
+ int id = function->node_def_size();
+ while (ContainsFunctionNodeWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ node->set_name(std::move(name));
+}
+
+void SetUniqueGraphFunctionName(const string& prefix,
+ FunctionDefLibrary* library,
+ FunctionDef* function) {
+ string name = prefix;
+ int id = library->function_size();
+ while (ContainsGraphFunctionWithName(name, *library)) {
+ name = strings::StrCat(prefix, "/_", id);
++id;
}
- node->set_name(strings::StrCat(op, "/_", id));
+ function->mutable_signature()->set_name(name);
}
} // end namespace graph_utils