diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-02-28 12:14:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-28 12:21:42 -0800 |
commit | 3dbbf740441cdd41b2dc998e09980d72d2e9d440 (patch) | |
tree | f01283e0ee5df55ee5d6b10b6b62e1ed09f86f9a | |
parent | 31421c3fa3a0585c01198458fa123c3493c21b62 (diff) |
In Grappler item builder, support inferring fetch nodes from siganture defs.
PiperOrigin-RevId: 187364078
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder.cc | 76 | ||||
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder_test.cc | 53 |
2 files changed, 117 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 606807b9e9..33ad426bbf 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -168,12 +168,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( // Fill in feed nodes from config, if any provided. for (const auto& feed_node : cfg.feed_nodes) { const string feed_name = NodeName(feed_node); - if (feed_name.empty()) { - LOG(ERROR) << "Invalid feed node name " << feed_node - << ", skipping this input."; - return nullptr; - } - VLOG(1) << "Will use feed node " << feed_name; new_item->feed.emplace_back(feed_name, Tensor()); } @@ -182,17 +176,75 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); if (nodes.has_node_list()) { for (const auto& node : nodes.node_list().value()) { - const string name = NodeName(node); - if (name.empty()) { - LOG(ERROR) << "Invalid fetch node name " << node - << ", skipping this input"; + new_item->fetch.push_back(NodeName(node)); + } + } + } + + // Detect feed and fetch nodes from signature defs. + for (const auto& name_and_signature : meta_graph.signature_def()) { + for (const auto& name_and_input : name_and_signature.second.inputs()) { + const TensorInfo& input = name_and_input.second; + if (input.has_coo_sparse()) { + // Define the shapes following the comment of CooSparse. + PartialTensorShape partial_shape_1d({-1}); + PartialTensorShape partial_shape_2d({-1, -1}); + TensorShape shape_1d; + TensorShape shape_2d; + if (!partial_shape_1d.AsTensorShape(&shape_1d) || + !partial_shape_2d.AsTensorShape(&shape_2d)) { + LOG(ERROR) << "Internal error when constructing tensor shapes."; return nullptr; } - VLOG(1) << "Will use fetch node " << name; - new_item->fetch.push_back(name); + + new_item->feed.emplace_back( + NodeName(input.coo_sparse().values_tensor_name()), + Tensor(input.dtype(), shape_1d)); + new_item->feed.emplace_back( + NodeName(input.coo_sparse().indices_tensor_name()), + Tensor(DT_INT64, shape_2d)); + new_item->feed.emplace_back( + NodeName(input.coo_sparse().dense_shape_tensor_name()), + Tensor(DT_INT64, shape_1d)); + } else { + new_item->feed.emplace_back( + NodeName(input.name()), + Tensor(input.dtype(), input.tensor_shape())); } } + for (const auto& name_and_output : name_and_signature.second.outputs()) { + const TensorInfo& output = name_and_output.second; + if (output.has_coo_sparse()) { + new_item->fetch.push_back( + NodeName(output.coo_sparse().values_tensor_name())); + new_item->fetch.push_back( + NodeName(output.coo_sparse().indices_tensor_name())); + new_item->fetch.push_back( + NodeName(output.coo_sparse().dense_shape_tensor_name())); + } else { + new_item->fetch.push_back(NodeName(output.name())); + } + } + } + + for (const auto& feed : new_item->feed) { + if (feed.first.empty()) { + LOG(ERROR) << "Invalid feed node name skipping this input"; + return nullptr; + } else { + VLOG(1) << "Will use feed node " << feed.first; + } + } + + for (const auto& fetch : new_item->fetch) { + if (fetch.empty()) { + LOG(ERROR) << "Invalid fetch node name skipping this input"; + return nullptr; + } else { + VLOG(1) << "Will use fetch node " << fetch; + } } + if (new_item->fetch.empty()) { LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input"; return nullptr; diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index ef95992af7..78cbff6c90 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -280,6 +280,59 @@ TEST_F(GrapplerItemBuilderTest, GraphWithFunctions) { ASSERT_TRUE(item != nullptr); } +TEST_F(GrapplerItemBuilderTest, FromGraphWithSignatureDef) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), 0); + auto y = ops::Const(s.WithOpName("y"), 1); + auto z = ops::Add(s.WithOpName("z"), x, y); + + MetaGraphDef meta_graph; + TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def())); + + TensorInfo input, output; + input.set_name("x"); + input.set_dtype(DT_FLOAT); + output.set_name("z"); + SignatureDef serving_signature; + (*serving_signature.mutable_inputs())["input"] = input; + (*serving_signature.mutable_outputs())["output"] = output; + (*meta_graph.mutable_signature_def())["serving"] = serving_signature; + + std::unique_ptr<GrapplerItem> item = + GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig()); + ASSERT_TRUE(item != nullptr); + + EXPECT_EQ(item->feed[0].first, "x"); + EXPECT_EQ(item->fetch[0], "z"); +} + +TEST_F(GrapplerItemBuilderTest, FromGraphWithIncompleteSignatureDef) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), 0); + auto y = ops::Const(s.WithOpName("y"), 1); + + MetaGraphDef meta_graph; + TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def())); + + CollectionDef train_op; + train_op.mutable_node_list()->add_value("y"); + (*meta_graph.mutable_collection_def())["train_op"] = train_op; + + TensorInfo input, output; + input.set_name("x"); + input.set_dtype(DT_FLOAT); + // Its coo_sparse proto is incomplete. + output.mutable_coo_sparse()->set_values_tensor_name("z"); + SignatureDef serving_signature; + (*serving_signature.mutable_inputs())["input"] = input; + (*serving_signature.mutable_outputs())["output"] = output; + (*meta_graph.mutable_signature_def())["serving"] = serving_signature; + + std::unique_ptr<GrapplerItem> item = + GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig()); + ASSERT_TRUE(item == nullptr); +} + } // namespace } // namespace grappler } // namespace tensorflow |