diff options
author | 2017-12-04 16:58:49 -0800 | |
---|---|---|
committer | 2017-12-04 17:02:31 -0800 | |
commit | 129892420278367aa774400455396e4e4d0734ba (patch) | |
tree | 7d7e7b35801585f9eb6515e3de3e6959f53d18c4 | |
parent | 868e2b344b3e7b2e6b069d5c6ec21d73959352c8 (diff) |
Fix bug with uniquified colocation attrs in ImportGraphDef.
The colocation attrs must be updated after all NodeDefs have been
processed. The nodes are processed and uniquified in topological
order, which allows us to update the inputs simultaneously due to the
topological ordering, but this doesn't work for the colocation groups.
I also considered updating all the NodeDefs with prefixes or unique
names at the very beginning, before starting conversion. This would
make the logic simpler, but require us to potentially keep a full copy
of all the NodeDefs in memory (so we could edit them), so I decided to
edit in-place after construction. We might want to consider this
alternate in future though.
PiperOrigin-RevId: 177890362
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 38 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 55 |
2 files changed, 77 insertions, 16 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 63e3d5ee7d..0fb61fd9af 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -159,6 +159,7 @@ class GraphConstructor { TF_RETURN_IF_ERROR(UpdateVersionDef()); TF_RETURN_IF_ERROR(PopulateReturnTensors()); TF_RETURN_IF_ERROR(PopulateReturnNodes()); + UpdateUniquifiedColocationNames(); FixupSourceAndSinkEdges(g_); return Status::OK(); } @@ -201,6 +202,11 @@ class GraphConstructor { void UniquifyNames(const std::vector<bool>& input_already_exists, NodeDef* node_def); + // Updates any constructed nodes' colocation group names if the name has been + // updated by UniquifyNames. This is called after all the nodes have been + // constructed so all the names have been uniquified if necessary. + void UpdateUniquifiedColocationNames(); + // Returns true if `name` already exists in `g_` (either as a node name or // prefix). bool NameExistsInGraph(StringPiece name); @@ -785,18 +791,30 @@ void GraphConstructor::UniquifyNames( id.first = iter->second; node_def->set_input(i, id.ToString()); } - // Update names of colocation groups - if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) { - auto* list = - node_def->mutable_attr()->at(kColocationAttrName).mutable_list(); - for (int i = 0; i < list->s_size(); ++i) { - StringPiece v(list->s(i)); - if (v.Consume(kColocationGroupPrefix)) { - auto iter = uniquified_names_.find(v.ToString()); - if (iter == uniquified_names_.end()) continue; - list->set_s(i, strings::StrCat(kColocationGroupPrefix, iter->second)); +} + +void GraphConstructor::UpdateUniquifiedColocationNames() { + for (const auto& pair : gdef_nodes_) { + Node* node = pair.second.node; + if (node == nullptr) continue; + std::vector<string> coloc_values; + Status status = + GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values); + if (!status.ok()) continue; + bool updated = false; + for (int i = 0; i < coloc_values.size(); ++i) { + StringPiece val(coloc_values[i]); + if (val.Consume(kColocationGroupPrefix)) { + const auto& name_pair = uniquified_names_.find(val.ToString()); + if (name_pair == uniquified_names_.end()) continue; + updated = true; + coloc_values[i] = + strings::StrCat(kColocationGroupPrefix, name_pair->second); } } + if (updated) { + node->AddAttr(kColocationAttrName, coloc_values); + } } } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 479f07f7f6..83aba6c9be 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1898,13 +1898,22 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { EXPECT_EQ(results.return_nodes[0]->name(), "A_5"); EXPECT_EQ(results.return_nodes[1]->name(), "B_5"); EXPECT_EQ(results.return_nodes[1]->def().input(0), "A:0"); +} + +TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames_ColocationGroups) { + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); + + // Create nodes 'A' and 'b" + ExpectOK( + "node { name: 'A' op: 'TestInput' }" + "node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A'] }"); // Check that colocation groups are updated - opts = ImportGraphDefOptions(); + ImportGraphDefOptions opts; opts.uniquify_names = true; opts.return_nodes.push_back("A"); opts.return_nodes.push_back("B"); - results = ImportGraphDefResults(); + ImportGraphDefResults results; ExpectOK( "node { name: 'A' op: 'TestInput' }" "node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A:0'] " @@ -1912,14 +1921,48 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { opts, &refiner, &results); ASSERT_EQ(results.return_nodes.size(), 2); - EXPECT_EQ(results.return_nodes[0]->name(), "A_6"); - EXPECT_EQ(results.return_nodes[1]->name(), "B_6"); - EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_6:0"); + EXPECT_EQ(results.return_nodes[0]->name(), "A_1"); + EXPECT_EQ(results.return_nodes[1]->name(), "B_1"); const AttrValue* class_attr = results.return_nodes[1]->attrs().Find(kColocationAttrName); ASSERT_TRUE(class_attr != nullptr); ASSERT_EQ(class_attr->list().s_size(), 1); - EXPECT_EQ(class_attr->list().s(0), "loc:@A_6"); + EXPECT_EQ(class_attr->list().s(0), "loc:@A_1"); + + results = ImportGraphDefResults(); + ExpectOK( + "node { name: 'A' op: 'TestInput' " + " attr { key: '_class' value { list { s:'loc:@B' } } } }" + "node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A:0'] }", + opts, &refiner, &results); + + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_nodes[0]->name(), "A_2"); + EXPECT_EQ(results.return_nodes[1]->name(), "B_2"); + class_attr = results.return_nodes[0]->attrs().Find(kColocationAttrName); + ASSERT_TRUE(class_attr != nullptr); + ASSERT_EQ(class_attr->list().s_size(), 1); + EXPECT_EQ(class_attr->list().s(0), "loc:@B_2"); + + results = ImportGraphDefResults(); + ExpectOK( + "node { name: 'A' op: 'TestInput' " + " attr { key: '_class' value { list { s:'loc:@B' } } } }" + "node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A:0'] " + " attr { key: '_class' value { list { s:'loc:@B' } } } }", + opts, &refiner, &results); + + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_nodes[0]->name(), "A_3"); + EXPECT_EQ(results.return_nodes[1]->name(), "B_3"); + class_attr = results.return_nodes[0]->attrs().Find(kColocationAttrName); + ASSERT_TRUE(class_attr != nullptr); + ASSERT_EQ(class_attr->list().s_size(), 1); + EXPECT_EQ(class_attr->list().s(0), "loc:@B_3"); + class_attr = results.return_nodes[1]->attrs().Find(kColocationAttrName); + ASSERT_TRUE(class_attr != nullptr); + ASSERT_EQ(class_attr->list().s_size(), 1); + EXPECT_EQ(class_attr->list().s(0), "loc:@B_3"); } TEST_F(GraphConstructorTest, ImportGraphDef_WithCycle) { |