aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-02-28 12:14:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-28 12:21:42 -0800
commit3dbbf740441cdd41b2dc998e09980d72d2e9d440 (patch)
treef01283e0ee5df55ee5d6b10b6b62e1ed09f86f9a
parent31421c3fa3a0585c01198458fa123c3493c21b62 (diff)
In Grappler item builder, support inferring fetch nodes from siganture defs.
PiperOrigin-RevId: 187364078
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc76
-rw-r--r--tensorflow/core/grappler/grappler_item_builder_test.cc53
2 files changed, 117 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 606807b9e9..33ad426bbf 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -168,12 +168,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
// Fill in feed nodes from config, if any provided.
for (const auto& feed_node : cfg.feed_nodes) {
const string feed_name = NodeName(feed_node);
- if (feed_name.empty()) {
- LOG(ERROR) << "Invalid feed node name " << feed_node
- << ", skipping this input.";
- return nullptr;
- }
- VLOG(1) << "Will use feed node " << feed_name;
new_item->feed.emplace_back(feed_name, Tensor());
}
@@ -182,17 +176,75 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
if (nodes.has_node_list()) {
for (const auto& node : nodes.node_list().value()) {
- const string name = NodeName(node);
- if (name.empty()) {
- LOG(ERROR) << "Invalid fetch node name " << node
- << ", skipping this input";
+ new_item->fetch.push_back(NodeName(node));
+ }
+ }
+ }
+
+ // Detect feed and fetch nodes from signature defs.
+ for (const auto& name_and_signature : meta_graph.signature_def()) {
+ for (const auto& name_and_input : name_and_signature.second.inputs()) {
+ const TensorInfo& input = name_and_input.second;
+ if (input.has_coo_sparse()) {
+ // Define the shapes following the comment of CooSparse.
+ PartialTensorShape partial_shape_1d({-1});
+ PartialTensorShape partial_shape_2d({-1, -1});
+ TensorShape shape_1d;
+ TensorShape shape_2d;
+ if (!partial_shape_1d.AsTensorShape(&shape_1d) ||
+ !partial_shape_2d.AsTensorShape(&shape_2d)) {
+ LOG(ERROR) << "Internal error when constructing tensor shapes.";
return nullptr;
}
- VLOG(1) << "Will use fetch node " << name;
- new_item->fetch.push_back(name);
+
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().values_tensor_name()),
+ Tensor(input.dtype(), shape_1d));
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().indices_tensor_name()),
+ Tensor(DT_INT64, shape_2d));
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().dense_shape_tensor_name()),
+ Tensor(DT_INT64, shape_1d));
+ } else {
+ new_item->feed.emplace_back(
+ NodeName(input.name()),
+ Tensor(input.dtype(), input.tensor_shape()));
}
}
+ for (const auto& name_and_output : name_and_signature.second.outputs()) {
+ const TensorInfo& output = name_and_output.second;
+ if (output.has_coo_sparse()) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().values_tensor_name()));
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().indices_tensor_name()));
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().dense_shape_tensor_name()));
+ } else {
+ new_item->fetch.push_back(NodeName(output.name()));
+ }
+ }
+ }
+
+ for (const auto& feed : new_item->feed) {
+ if (feed.first.empty()) {
+ LOG(ERROR) << "Invalid feed node name skipping this input";
+ return nullptr;
+ } else {
+ VLOG(1) << "Will use feed node " << feed.first;
+ }
+ }
+
+ for (const auto& fetch : new_item->fetch) {
+ if (fetch.empty()) {
+ LOG(ERROR) << "Invalid fetch node name skipping this input";
+ return nullptr;
+ } else {
+ VLOG(1) << "Will use fetch node " << fetch;
+ }
}
+
if (new_item->fetch.empty()) {
LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
return nullptr;
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index ef95992af7..78cbff6c90 100644
--- a/tensorflow/core/grappler/grappler_item_builder_test.cc
+++ b/tensorflow/core/grappler/grappler_item_builder_test.cc
@@ -280,6 +280,59 @@ TEST_F(GrapplerItemBuilderTest, GraphWithFunctions) {
ASSERT_TRUE(item != nullptr);
}
+TEST_F(GrapplerItemBuilderTest, FromGraphWithSignatureDef) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 0);
+ auto y = ops::Const(s.WithOpName("y"), 1);
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ TensorInfo input, output;
+ input.set_name("x");
+ input.set_dtype(DT_FLOAT);
+ output.set_name("z");
+ SignatureDef serving_signature;
+ (*serving_signature.mutable_inputs())["input"] = input;
+ (*serving_signature.mutable_outputs())["output"] = output;
+ (*meta_graph.mutable_signature_def())["serving"] = serving_signature;
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig());
+ ASSERT_TRUE(item != nullptr);
+
+ EXPECT_EQ(item->feed[0].first, "x");
+ EXPECT_EQ(item->fetch[0], "z");
+}
+
+TEST_F(GrapplerItemBuilderTest, FromGraphWithIncompleteSignatureDef) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 0);
+ auto y = ops::Const(s.WithOpName("y"), 1);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ CollectionDef train_op;
+ train_op.mutable_node_list()->add_value("y");
+ (*meta_graph.mutable_collection_def())["train_op"] = train_op;
+
+ TensorInfo input, output;
+ input.set_name("x");
+ input.set_dtype(DT_FLOAT);
+ // Its coo_sparse proto is incomplete.
+ output.mutable_coo_sparse()->set_values_tensor_name("z");
+ SignatureDef serving_signature;
+ (*serving_signature.mutable_inputs())["input"] = input;
+ (*serving_signature.mutable_outputs())["output"] = output;
+ (*meta_graph.mutable_signature_def())["serving"] = serving_signature;
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig());
+ ASSERT_TRUE(item == nullptr);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow