diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 00:34:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 00:46:18 -0700 |
commit | 4e56ea8f9bc398e4cd8bf66abf58cc872c922067 (patch) | |
tree | bdb92e50d7ccbc2a5daeacdf7019acab5181f2ee /tensorflow | |
parent | c85998ba9ca005774d81f0f15ee8055f19c6a888 (diff) |
Add support for explicit fetches when creating grappler items
PiperOrigin-RevId: 214732243
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder.h | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder_test.cc | 23 |
3 files changed, 31 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 029515ad3c..369046666d 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -192,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( const string feed_name = NodeName(feed_node); new_item->feed.emplace_back(feed_name, Tensor()); } + for (const auto& fetch_node : cfg.fetch_nodes) { + new_item->fetch.emplace_back(NodeName(fetch_node)); + } - // Attempt to detect the fetch node(s). - if (meta_graph.collection_def().count("train_op") > 0) { + // Attempt to detect the fetch node(s) if they were not set explicitly. + if (new_item->fetch.empty() && + meta_graph.collection_def().count("train_op") > 0) { const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); if (nodes.has_node_list()) { for (const auto& node : nodes.node_list().value()) { diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index aafd2fdcda..1698587f8c 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -49,6 +49,8 @@ struct ItemConfig { bool prune_graph = false; // Override feed nodes list. std::set<string> feed_nodes; + // Override fetch nodes list. + std::set<string> fetch_nodes; }; // Factory method for creating a GrapplerItem from a MetaGraphDef. diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 4b90bf3038..d00981f174 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) { EXPECT_EQ(item2->feed[0].second.NumElements(), 1); } +TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), 0); + auto y = ops::Const(s.WithOpName("y"), 1); + auto z = ops::Add(s.WithOpName("z"), x, y); + + MetaGraphDef meta_graph; + TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def())); + + ItemConfig config; + config.feed_nodes.insert("x"); + config.fetch_nodes.insert("z"); + + std::unique_ptr<GrapplerItem> item = + GrapplerItemFromMetaGraphDef("0", meta_graph, config); + ASSERT_TRUE(item != nullptr); + + EXPECT_EQ(item->feed.size(), 1); + EXPECT_EQ(item->fetch.size(), 1); + EXPECT_EQ(item->feed[0].first, "x"); + EXPECT_EQ(item->fetch[0], "z"); +} + } // namespace } // namespace grappler } // namespace tensorflow |