aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 00:34:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 00:46:18 -0700
commit4e56ea8f9bc398e4cd8bf66abf58cc872c922067 (patch)
treebdb92e50d7ccbc2a5daeacdf7019acab5181f2ee /tensorflow/core/grappler
parentc85998ba9ca005774d81f0f15ee8055f19c6a888 (diff)
Add support for explicit fetches when creating grappler items
PiperOrigin-RevId: 214732243
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc8
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.h2
-rw-r--r--tensorflow/core/grappler/grappler_item_builder_test.cc23
3 files changed, 31 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 029515ad3c..369046666d 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -192,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const string feed_name = NodeName(feed_node);
new_item->feed.emplace_back(feed_name, Tensor());
}
+ for (const auto& fetch_node : cfg.fetch_nodes) {
+ new_item->fetch.emplace_back(NodeName(fetch_node));
+ }
- // Attempt to detect the fetch node(s).
- if (meta_graph.collection_def().count("train_op") > 0) {
+ // Attempt to detect the fetch node(s) if they were not set explicitly.
+ if (new_item->fetch.empty() &&
+ meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
if (nodes.has_node_list()) {
for (const auto& node : nodes.node_list().value()) {
diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h
index aafd2fdcda..1698587f8c 100644
--- a/tensorflow/core/grappler/grappler_item_builder.h
+++ b/tensorflow/core/grappler/grappler_item_builder.h
@@ -49,6 +49,8 @@ struct ItemConfig {
bool prune_graph = false;
// Override feed nodes list.
std::set<string> feed_nodes;
+ // Override fetch nodes list.
+ std::set<string> fetch_nodes;
};
// Factory method for creating a GrapplerItem from a MetaGraphDef.
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index 4b90bf3038..d00981f174 100644
--- a/tensorflow/core/grappler/grappler_item_builder_test.cc
+++ b/tensorflow/core/grappler/grappler_item_builder_test.cc
@@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
}
+TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 0);
+ auto y = ops::Const(s.WithOpName("y"), 1);
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ ItemConfig config;
+ config.feed_nodes.insert("x");
+ config.fetch_nodes.insert("z");
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, config);
+ ASSERT_TRUE(item != nullptr);
+
+ EXPECT_EQ(item->feed.size(), 1);
+ EXPECT_EQ(item->fetch.size(), 1);
+ EXPECT_EQ(item->feed[0].first, "x");
+ EXPECT_EQ(item->fetch[0], "z");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow