diff options
author | 2018-03-13 11:27:46 -0700 | |
---|---|---|
committer | 2018-03-13 11:32:37 -0700 | |
commit | 7d02968ce04d9576f152e6cfd0c88da096e862a4 (patch) | |
tree | cad0502ef74c0e7fcba81fb579eed3d3e3d71e56 /tensorflow/core/grappler/grappler_item_builder.cc | |
parent | ea9e65c94ad71ca86d2be91c4109c62269b42cf8 (diff) |
Replace the unknown dimension of signature input when building grappler items.
Fix the bug where same feed nodes or fetch nodes would be added more than once.
PiperOrigin-RevId: 188902101
Diffstat (limited to 'tensorflow/core/grappler/grappler_item_builder.cc')
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder.cc | 144 |
1 files changed, 100 insertions, 44 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 04c7dae30b..d7b300321a 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -152,6 +153,27 @@ Status PruneGraph(GrapplerItem* item) { return Status::OK(); } +// Replace any unknown dimensions in a shape with +// cfg.placeholder_unknown_output_shape_dim if it is no less than 0. +Status ReplaceUnknownShapeDim(const ItemConfig& cfg, + const TensorShapeProto& shape_pb_in, + TensorShapeProto* shape_pb_out, + TensorShape* shape_out) { + std::vector<int32> dims; + for (const auto& dim_proto : shape_pb_in.dim()) { + if (cfg.placeholder_unknown_output_shape_dim >= 0 && + dim_proto.size() == -1) { + dims.push_back(cfg.placeholder_unknown_output_shape_dim); + shape_pb_out->add_dim()->set_size( + cfg.placeholder_unknown_output_shape_dim); + } else { + dims.push_back(std::max<int32>(1, dim_proto.size())); + shape_pb_out->add_dim()->set_size(dim_proto.size()); + } + } + return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out); +} + } // namespace // static @@ -181,48 +203,92 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( } } - // Detect feed and fetch nodes from signature defs. + // Detect feed and fetch nodes from signature defs. Signatures may share same + // inputs or outputs. + std::unordered_set<string> signature_feed_nodes; + std::unordered_set<string> signature_fetch_nodes; for (const auto& name_and_signature : meta_graph.signature_def()) { for (const auto& name_and_input : name_and_signature.second.inputs()) { const TensorInfo& input = name_and_input.second; if (input.has_coo_sparse()) { // Define the shapes following the comment of CooSparse. - PartialTensorShape partial_shape_1d({-1}); - PartialTensorShape partial_shape_2d({-1, -1}); - TensorShape shape_1d; - TensorShape shape_2d; - if (!partial_shape_1d.AsTensorShape(&shape_1d) || - !partial_shape_2d.AsTensorShape(&shape_2d)) { - LOG(ERROR) << "Internal error when constructing tensor shapes."; - return nullptr; + // TODO(yuefengz): we probably want to use different dim values for the + // three tensors of a SparseTensor. + int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim); + TensorShape shape_1d({dim}); + TensorShape shape_2d({dim, dim}); + + if (gtl::InsertIfNotPresent( + &signature_feed_nodes, + NodeName(input.coo_sparse().values_tensor_name()))) { + Tensor value_tensor(input.dtype(), shape_1d); + InitializeTensor(input.dtype(), &value_tensor); + new_item->feed.emplace_back( + NodeName(input.coo_sparse().values_tensor_name()), value_tensor); + } + if (gtl::InsertIfNotPresent( + &signature_feed_nodes, + NodeName(input.coo_sparse().indices_tensor_name()))) { + Tensor indices_tensor(DT_INT64, shape_2d); + InitializeTensor(input.dtype(), &indices_tensor); + new_item->feed.emplace_back( + NodeName(input.coo_sparse().indices_tensor_name()), + indices_tensor); + } + if (gtl::InsertIfNotPresent( + &signature_feed_nodes, + NodeName(input.coo_sparse().dense_shape_tensor_name()))) { + Tensor dense_shape_tensor(DT_INT64, shape_1d); + InitializeTensor(input.dtype(), &dense_shape_tensor); + new_item->feed.emplace_back( + NodeName(input.coo_sparse().dense_shape_tensor_name()), + dense_shape_tensor); } - - new_item->feed.emplace_back( - NodeName(input.coo_sparse().values_tensor_name()), - Tensor(input.dtype(), shape_1d)); - new_item->feed.emplace_back( - NodeName(input.coo_sparse().indices_tensor_name()), - Tensor(DT_INT64, shape_2d)); - new_item->feed.emplace_back( - NodeName(input.coo_sparse().dense_shape_tensor_name()), - Tensor(DT_INT64, shape_1d)); } else { - new_item->feed.emplace_back( - NodeName(input.name()), - Tensor(input.dtype(), input.tensor_shape())); + if (gtl::InsertIfNotPresent(&signature_feed_nodes, + NodeName(input.name()))) { + TensorShape shape; + TensorShapeProto shape_proto; + Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(), + &shape_proto, &shape); + if (!s.ok()) { + LOG(ERROR) << "Invalid shape for signature input " << input.name() + << ": " << s << ", skipping this input"; + return nullptr; + } + + Tensor fake_input(input.dtype(), shape); + InitializeTensor(input.dtype(), &fake_input); + new_item->feed.emplace_back(NodeName(input.name()), fake_input); + } } } for (const auto& name_and_output : name_and_signature.second.outputs()) { const TensorInfo& output = name_and_output.second; if (output.has_coo_sparse()) { - new_item->fetch.push_back( - NodeName(output.coo_sparse().values_tensor_name())); - new_item->fetch.push_back( - NodeName(output.coo_sparse().indices_tensor_name())); - new_item->fetch.push_back( - NodeName(output.coo_sparse().dense_shape_tensor_name())); + if (gtl::InsertIfNotPresent( + &signature_fetch_nodes, + NodeName(output.coo_sparse().values_tensor_name()))) { + new_item->fetch.push_back( + NodeName(output.coo_sparse().values_tensor_name())); + } + if (gtl::InsertIfNotPresent( + &signature_fetch_nodes, + NodeName(output.coo_sparse().indices_tensor_name()))) { + new_item->fetch.push_back( + NodeName(output.coo_sparse().indices_tensor_name())); + } + if (gtl::InsertIfNotPresent( + &signature_fetch_nodes, + NodeName(output.coo_sparse().dense_shape_tensor_name()))) { + new_item->fetch.push_back( + NodeName(output.coo_sparse().dense_shape_tensor_name())); + } } else { - new_item->fetch.push_back(NodeName(output.name())); + if (gtl::InsertIfNotPresent(&signature_fetch_nodes, + NodeName(output.name()))) { + new_item->fetch.push_back(NodeName(output.name())); + } } } } @@ -377,20 +443,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( // shape is not empty if the shape is partially defined. TensorShape shape; TensorShapeProto shape_proto; - std::vector<int32> dims; - for (const auto& dim_proto : node.attr().at("shape").shape().dim()) { - if (cfg.placeholder_unknown_output_shape_dim >= 0 && - dim_proto.size() == -1) { - dims.push_back(cfg.placeholder_unknown_output_shape_dim); - shape_proto.add_dim()->set_size( - cfg.placeholder_unknown_output_shape_dim); - } else { - dims.push_back(std::max<int32>(1, dim_proto.size())); - shape_proto.add_dim()->set_size(dim_proto.size()); - } - } - Status make_shape_status = - TensorShapeUtils::MakeShape(dims.data(), dims.size(), &shape); + Status make_shape_status = ReplaceUnknownShapeDim( + cfg, node.attr().at("shape").shape(), &shape_proto, &shape); if (!make_shape_status.ok()) { LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": " << make_shape_status << ", skipping this input"; @@ -430,7 +484,9 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( if (cfg.feed_nodes.empty()) { // No specific feed nodes were given. Assume all placeholders are fed. - new_item->feed.emplace_back(node.name(), fake_input); + if (signature_feed_nodes.count(node.name()) == 0) { + new_item->feed.emplace_back(node.name(), fake_input); + } } else if (cfg.feed_nodes.count(node.name()) > 0) { // If specific feed nodes were given, only update their tensors. auto it = find_if(new_item->feed.begin(), new_item->feed.end(), |