aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-12-04 16:58:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 17:02:31 -0800
commit129892420278367aa774400455396e4e4d0734ba (patch)
tree7d7e7b35801585f9eb6515e3de3e6959f53d18c4
parent868e2b344b3e7b2e6b069d5c6ec21d73959352c8 (diff)
Fix bug with uniquified colocation attrs in ImportGraphDef.
The colocation attrs must be updated after all NodeDefs have been processed. The nodes are processed and uniquified in topological order, which allows us to update the inputs simultaneously due to the topological ordering, but this doesn't work for the colocation groups. I also considered updating all the NodeDefs with prefixes or unique names at the very beginning, before starting conversion. This would make the logic simpler, but require us to potentially keep a full copy of all the NodeDefs in memory (so we could edit them), so I decided to edit in-place after construction. We might want to consider this alternate in future though. PiperOrigin-RevId: 177890362
-rw-r--r--tensorflow/core/graph/graph_constructor.cc38
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc55
2 files changed, 77 insertions, 16 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 63e3d5ee7d..0fb61fd9af 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -159,6 +159,7 @@ class GraphConstructor {
TF_RETURN_IF_ERROR(UpdateVersionDef());
TF_RETURN_IF_ERROR(PopulateReturnTensors());
TF_RETURN_IF_ERROR(PopulateReturnNodes());
+ UpdateUniquifiedColocationNames();
FixupSourceAndSinkEdges(g_);
return Status::OK();
}
@@ -201,6 +202,11 @@ class GraphConstructor {
void UniquifyNames(const std::vector<bool>& input_already_exists,
NodeDef* node_def);
+ // Updates any constructed nodes' colocation group names if the name has been
+ // updated by UniquifyNames. This is called after all the nodes have been
+ // constructed so all the names have been uniquified if necessary.
+ void UpdateUniquifiedColocationNames();
+
// Returns true if `name` already exists in `g_` (either as a node name or
// prefix).
bool NameExistsInGraph(StringPiece name);
@@ -785,18 +791,30 @@ void GraphConstructor::UniquifyNames(
id.first = iter->second;
node_def->set_input(i, id.ToString());
}
- // Update names of colocation groups
- if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
- auto* list =
- node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
- for (int i = 0; i < list->s_size(); ++i) {
- StringPiece v(list->s(i));
- if (v.Consume(kColocationGroupPrefix)) {
- auto iter = uniquified_names_.find(v.ToString());
- if (iter == uniquified_names_.end()) continue;
- list->set_s(i, strings::StrCat(kColocationGroupPrefix, iter->second));
+}
+
+void GraphConstructor::UpdateUniquifiedColocationNames() {
+ for (const auto& pair : gdef_nodes_) {
+ Node* node = pair.second.node;
+ if (node == nullptr) continue;
+ std::vector<string> coloc_values;
+ Status status =
+ GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values);
+ if (!status.ok()) continue;
+ bool updated = false;
+ for (int i = 0; i < coloc_values.size(); ++i) {
+ StringPiece val(coloc_values[i]);
+ if (val.Consume(kColocationGroupPrefix)) {
+ const auto& name_pair = uniquified_names_.find(val.ToString());
+ if (name_pair == uniquified_names_.end()) continue;
+ updated = true;
+ coloc_values[i] =
+ strings::StrCat(kColocationGroupPrefix, name_pair->second);
}
}
+ if (updated) {
+ node->AddAttr(kColocationAttrName, coloc_values);
+ }
}
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 479f07f7f6..83aba6c9be 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -1898,13 +1898,22 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
EXPECT_EQ(results.return_nodes[0]->name(), "A_5");
EXPECT_EQ(results.return_nodes[1]->name(), "B_5");
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A:0");
+}
+
+TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames_ColocationGroups) {
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
+
+ // Create nodes 'A' and 'b"
+ ExpectOK(
+ "node { name: 'A' op: 'TestInput' }"
+ "node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A'] }");
// Check that colocation groups are updated
- opts = ImportGraphDefOptions();
+ ImportGraphDefOptions opts;
opts.uniquify_names = true;
opts.return_nodes.push_back("A");
opts.return_nodes.push_back("B");
- results = ImportGraphDefResults();
+ ImportGraphDefResults results;
ExpectOK(
"node { name: 'A' op: 'TestInput' }"
"node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A:0'] "
@@ -1912,14 +1921,48 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
opts, &refiner, &results);
ASSERT_EQ(results.return_nodes.size(), 2);
- EXPECT_EQ(results.return_nodes[0]->name(), "A_6");
- EXPECT_EQ(results.return_nodes[1]->name(), "B_6");
- EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_6:0");
+ EXPECT_EQ(results.return_nodes[0]->name(), "A_1");
+ EXPECT_EQ(results.return_nodes[1]->name(), "B_1");
const AttrValue* class_attr =
results.return_nodes[1]->attrs().Find(kColocationAttrName);
ASSERT_TRUE(class_attr != nullptr);
ASSERT_EQ(class_attr->list().s_size(), 1);
- EXPECT_EQ(class_attr->list().s(0), "loc:@A_6");
+ EXPECT_EQ(class_attr->list().s(0), "loc:@A_1");
+
+ results = ImportGraphDefResults();
+ ExpectOK(
+ "node { name: 'A' op: 'TestInput' "
+ " attr { key: '_class' value { list { s:'loc:@B' } } } }"
+ "node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A:0'] }",
+ opts, &refiner, &results);
+
+ ASSERT_EQ(results.return_nodes.size(), 2);
+ EXPECT_EQ(results.return_nodes[0]->name(), "A_2");
+ EXPECT_EQ(results.return_nodes[1]->name(), "B_2");
+ class_attr = results.return_nodes[0]->attrs().Find(kColocationAttrName);
+ ASSERT_TRUE(class_attr != nullptr);
+ ASSERT_EQ(class_attr->list().s_size(), 1);
+ EXPECT_EQ(class_attr->list().s(0), "loc:@B_2");
+
+ results = ImportGraphDefResults();
+ ExpectOK(
+ "node { name: 'A' op: 'TestInput' "
+ " attr { key: '_class' value { list { s:'loc:@B' } } } }"
+ "node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A:0'] "
+ " attr { key: '_class' value { list { s:'loc:@B' } } } }",
+ opts, &refiner, &results);
+
+ ASSERT_EQ(results.return_nodes.size(), 2);
+ EXPECT_EQ(results.return_nodes[0]->name(), "A_3");
+ EXPECT_EQ(results.return_nodes[1]->name(), "B_3");
+ class_attr = results.return_nodes[0]->attrs().Find(kColocationAttrName);
+ ASSERT_TRUE(class_attr != nullptr);
+ ASSERT_EQ(class_attr->list().s_size(), 1);
+ EXPECT_EQ(class_attr->list().s(0), "loc:@B_3");
+ class_attr = results.return_nodes[1]->attrs().Find(kColocationAttrName);
+ ASSERT_TRUE(class_attr != nullptr);
+ ASSERT_EQ(class_attr->list().s_size(), 1);
+ EXPECT_EQ(class_attr->list().s(0), "loc:@B_3");
}
TEST_F(GraphConstructorTest, ImportGraphDef_WithCycle) {