aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/grappler_item_builder.cc
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-02-28 12:14:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-28 12:21:42 -0800
commit3dbbf740441cdd41b2dc998e09980d72d2e9d440 (patch)
treef01283e0ee5df55ee5d6b10b6b62e1ed09f86f9a /tensorflow/core/grappler/grappler_item_builder.cc
parent31421c3fa3a0585c01198458fa123c3493c21b62 (diff)
In Grappler item builder, support inferring fetch nodes from siganture defs.
PiperOrigin-RevId: 187364078
Diffstat (limited to 'tensorflow/core/grappler/grappler_item_builder.cc')
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc76
1 files changed, 64 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 606807b9e9..33ad426bbf 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -168,12 +168,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
// Fill in feed nodes from config, if any provided.
for (const auto& feed_node : cfg.feed_nodes) {
const string feed_name = NodeName(feed_node);
- if (feed_name.empty()) {
- LOG(ERROR) << "Invalid feed node name " << feed_node
- << ", skipping this input.";
- return nullptr;
- }
- VLOG(1) << "Will use feed node " << feed_name;
new_item->feed.emplace_back(feed_name, Tensor());
}
@@ -182,17 +176,75 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
if (nodes.has_node_list()) {
for (const auto& node : nodes.node_list().value()) {
- const string name = NodeName(node);
- if (name.empty()) {
- LOG(ERROR) << "Invalid fetch node name " << node
- << ", skipping this input";
+ new_item->fetch.push_back(NodeName(node));
+ }
+ }
+ }
+
+ // Detect feed and fetch nodes from signature defs.
+ for (const auto& name_and_signature : meta_graph.signature_def()) {
+ for (const auto& name_and_input : name_and_signature.second.inputs()) {
+ const TensorInfo& input = name_and_input.second;
+ if (input.has_coo_sparse()) {
+ // Define the shapes following the comment of CooSparse.
+ PartialTensorShape partial_shape_1d({-1});
+ PartialTensorShape partial_shape_2d({-1, -1});
+ TensorShape shape_1d;
+ TensorShape shape_2d;
+ if (!partial_shape_1d.AsTensorShape(&shape_1d) ||
+ !partial_shape_2d.AsTensorShape(&shape_2d)) {
+ LOG(ERROR) << "Internal error when constructing tensor shapes.";
return nullptr;
}
- VLOG(1) << "Will use fetch node " << name;
- new_item->fetch.push_back(name);
+
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().values_tensor_name()),
+ Tensor(input.dtype(), shape_1d));
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().indices_tensor_name()),
+ Tensor(DT_INT64, shape_2d));
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().dense_shape_tensor_name()),
+ Tensor(DT_INT64, shape_1d));
+ } else {
+ new_item->feed.emplace_back(
+ NodeName(input.name()),
+ Tensor(input.dtype(), input.tensor_shape()));
}
}
+ for (const auto& name_and_output : name_and_signature.second.outputs()) {
+ const TensorInfo& output = name_and_output.second;
+ if (output.has_coo_sparse()) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().values_tensor_name()));
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().indices_tensor_name()));
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().dense_shape_tensor_name()));
+ } else {
+ new_item->fetch.push_back(NodeName(output.name()));
+ }
+ }
+ }
+
+ for (const auto& feed : new_item->feed) {
+ if (feed.first.empty()) {
+ LOG(ERROR) << "Invalid feed node name skipping this input";
+ return nullptr;
+ } else {
+ VLOG(1) << "Will use feed node " << feed.first;
+ }
+ }
+
+ for (const auto& fetch : new_item->fetch) {
+ if (fetch.empty()) {
+ LOG(ERROR) << "Invalid fetch node name skipping this input";
+ return nullptr;
+ } else {
+ VLOG(1) << "Will use fetch node " << fetch;
+ }
}
+
if (new_item->fetch.empty()) {
LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
return nullptr;