aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-10-03 11:24:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 11:29:40 -0700
commit47eafbaf43c763dc65a2cd3cfd9ecbd8fbbdf668 (patch)
treeb9426571254706e767b5f09159c81b26bdf64d96
parent880dcb7a91e5ee497045614d9c5f4ab93c9ffacf (diff)
[tf.data] Add utility to deduplicate graph node names (after vectorization)
PiperOrigin-RevId: 215595078
-rw-r--r--tensorflow/core/graph/graph.cc5
-rw-r--r--tensorflow/core/graph/graph.h1
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD2
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc21
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h9
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc28
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc2
7 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 4c0cd14ff1..7a4a0096fa 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -192,6 +192,11 @@ void Node::ClearAttr(const string& name) {
(*props_->node_def.mutable_attr()).erase(name);
}
+void Node::set_name(string name) {
+ MaybeCopyOnWrite();
+ props_->node_def.set_name(std::move(name));
+}
+
void Node::set_requested_device(const string& device) {
MaybeCopyOnWrite();
props_->node_def.set_device(device);
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 72cef07072..2944951f82 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -72,6 +72,7 @@ class Node {
int id() const { return id_; }
int cost_id() const { return cost_id_; }
const string& name() const;
+ void set_name(string name);
const string& type_string() const;
// def() provides the NodeDef the user supplied, but the specifics
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 5a3abbb545..755af3361e 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -129,6 +129,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -138,6 +139,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 3eaaf8fbef..b863a25dc5 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -272,6 +273,26 @@ void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
->MergeFrom(second.attr().at(attribute_name).list());
}
+Status EnsureNodeNamesUnique(Graph* g) {
+ // Modeled after Scope::Impl::GetUniqueName
+ std::unordered_map<string, int> name_map;
+
+ for (auto node : g->op_nodes()) {
+ const string& prefix = node->name();
+ if (auto entry = gtl::FindOrNull(name_map, prefix)) {
+ string unique_name;
+ do {
+ unique_name = strings::StrCat(prefix, "_", ++(*entry));
+ } while (name_map.find(unique_name) != name_map.end());
+ name_map.insert({unique_name, 0});
+ node->set_name(std::move(unique_name));
+ } else {
+ name_map.insert({node->name(), 0});
+ }
+ }
+
+ return Status::OK();
+}
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 3af34f6904..d130fee204 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -131,6 +132,14 @@ void CopyAttribute(const string& attribute_name, const NodeDef& from,
void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
const NodeDef& second, NodeDef* to_node);
+// Checks that all nodes in the graphs have unique names, and sets their names
+// to be unique if they are not already. This is necessary as Graph does not
+// have the provisions to deduplicate names, and name deduplication elsewhere
+// in tensorflow happens in other layers (for example, in the Scope class of the
+// C++ API). Note that the nodes in the graph are identified by their id,
+// and renaming nodes does not mutate any edges.
+Status EnsureNodeNamesUnique(Graph* g);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index db986542b2..4ab6d71532 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -229,6 +230,33 @@ TEST(GraphUtilsTest, GetInputNode) {
EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
}
+TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
+ Graph g(OpRegistry::Global());
+
+ Node *const_0, *const_1, *const_2;
+
+ // Arbitrary const
+ Tensor tensor(DT_INT32, {});
+ tensor.scalar<int32>()() = 5;
+
+ for (auto node : {&const_0, &const_1}) {
+ TF_EXPECT_OK(NodeBuilder("Const", "Const")
+ .Attr("value", tensor)
+ .Attr("dtype", DT_INT32)
+ .Finalize(&g, node));
+ }
+ // Make sure generated name doesn't clash with existing name either
+ TF_EXPECT_OK(NodeBuilder("Const_1", "Const")
+ .Attr("value", tensor)
+ .Attr("dtype", DT_INT32)
+ .Finalize(&g, &const_2));
+
+ TF_EXPECT_OK(EnsureNodeNamesUnique(&g));
+ EXPECT_NE(const_0->name(), const_1->name());
+ EXPECT_NE(const_1->name(), const_2->name());
+ EXPECT_NE(const_0->name(), const_2->name());
+}
+
} // namespace
} // namespace graph_utils
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index cea667f668..2d6cf562b1 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -352,6 +352,8 @@ Status Vectorization::Initialize(const FunctionDef& outer_scope,
Status Vectorization::GetResult(FunctionDef** vectorized_function) {
TF_RETURN_IF_ERROR(status_);
+ TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
+ TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));
if (!map_defun_fn_->ret_nodes.empty()) {
FunctionDef* map_defun_fn = lib_->add_function();