diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-03 11:24:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 11:29:40 -0700 |
commit | 47eafbaf43c763dc65a2cd3cfd9ecbd8fbbdf668 (patch) | |
tree | b9426571254706e767b5f09159c81b26bdf64d96 | |
parent | 880dcb7a91e5ee497045614d9c5f4ab93c9ffacf (diff) |
[tf.data] Add utility to deduplicate graph node names (after vectorization)
PiperOrigin-RevId: 215595078
-rw-r--r-- | tensorflow/core/graph/graph.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.h | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/graph_utils.cc | 21 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/graph_utils.h | 9 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/graph_utils_test.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 2 |
7 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 4c0cd14ff1..7a4a0096fa 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -192,6 +192,11 @@ void Node::ClearAttr(const string& name) { (*props_->node_def.mutable_attr()).erase(name); } +void Node::set_name(string name) { + MaybeCopyOnWrite(); + props_->node_def.set_name(std::move(name)); +} + void Node::set_requested_device(const string& device) { MaybeCopyOnWrite(); props_->node_def.set_device(device); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 72cef07072..2944951f82 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -72,6 +72,7 @@ class Node { int id() const { return id_; } int cost_id() const { return cost_id_; } const string& name() const; + void set_name(string name); const string& type_string() const; // def() provides the NodeDef the user supplied, but the specifics diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 5a3abbb545..755af3361e 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -129,6 +129,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:utils", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -138,6 +139,7 @@ tf_cc_test( visibility = ["//visibility:public"], deps = [ ":graph_utils", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 3eaaf8fbef..b863a25dc5 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -272,6 +273,26 @@ void ConcatAttributeList(const string& attribute_name, const NodeDef& first, ->MergeFrom(second.attr().at(attribute_name).list()); } +Status EnsureNodeNamesUnique(Graph* g) { + // Modeled after Scope::Impl::GetUniqueName + std::unordered_map<string, int> name_map; + + for (auto node : g->op_nodes()) { + const string& prefix = node->name(); + if (auto entry = gtl::FindOrNull(name_map, prefix)) { + string unique_name; + do { + unique_name = strings::StrCat(prefix, "_", ++(*entry)); + } while (name_map.find(unique_name) != name_map.end()); + name_map.insert({unique_name, 0}); + node->set_name(std::move(unique_name)); + } else { + name_map.insert({node->name(), 0}); + } + } + + return Status::OK(); +} } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 3af34f6904..d130fee204 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" @@ -131,6 +132,14 @@ void CopyAttribute(const string& attribute_name, const NodeDef& from, void ConcatAttributeList(const string& attribute_name, const NodeDef& first, const NodeDef& second, NodeDef* to_node); +// Checks that all nodes in the graphs have unique names, and sets their names +// to be unique if they are not already. This is necessary as Graph does not +// have the provisions to deduplicate names, and name deduplication elsewhere +// in tensorflow happens in other layers (for example, in the Scope class of the +// C++ API). Note that the nodes in the graph are identified by their id, +// and renaming nodes does not mutate any edges. +Status EnsureNodeNamesUnique(Graph* g); + } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index db986542b2..4ab6d71532 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -229,6 +230,33 @@ TEST(GraphUtilsTest, GetInputNode) { EXPECT_EQ(GetInputNode(*node1, graph), nullptr); } +TEST(GraphUtilsTest, EnsureNodeNamesUnique) { + Graph g(OpRegistry::Global()); + + Node *const_0, *const_1, *const_2; + + // Arbitrary const + Tensor tensor(DT_INT32, {}); + tensor.scalar<int32>()() = 5; + + for (auto node : {&const_0, &const_1}) { + TF_EXPECT_OK(NodeBuilder("Const", "Const") + .Attr("value", tensor) + .Attr("dtype", DT_INT32) + .Finalize(&g, node)); + } + // Make sure generated name doesn't clash with existing name either + TF_EXPECT_OK(NodeBuilder("Const_1", "Const") + .Attr("value", tensor) + .Attr("dtype", DT_INT32) + .Finalize(&g, &const_2)); + + TF_EXPECT_OK(EnsureNodeNamesUnique(&g)); + EXPECT_NE(const_0->name(), const_1->name()); + EXPECT_NE(const_1->name(), const_2->name()); + EXPECT_NE(const_0->name(), const_2->name()); +} + } // namespace } // namespace graph_utils } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index cea667f668..2d6cf562b1 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -352,6 +352,8 @@ Status Vectorization::Initialize(const FunctionDef& outer_scope, Status Vectorization::GetResult(FunctionDef** vectorized_function) { TF_RETURN_IF_ERROR(status_); + TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get())); + TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph)); if (!map_defun_fn_->ret_nodes.empty()) { FunctionDef* map_defun_fn = lib_->add_function(); |