diff options
author | 2018-03-13 11:27:46 -0700 | |
---|---|---|
committer | 2018-03-13 11:32:37 -0700 | |
commit | 7d02968ce04d9576f152e6cfd0c88da096e862a4 (patch) | |
tree | cad0502ef74c0e7fcba81fb579eed3d3e3d71e56 /tensorflow/core | |
parent | ea9e65c94ad71ca86d2be91c4109c62269b42cf8 (diff) |
Replace the unknown dimension of signature input when building grappler items.
Fix the bug where same feed nodes or fetch nodes would be added more than once.
PiperOrigin-RevId: 188902101
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder.cc | 144 | ||||
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder_test.cc | 51 |
2 files changed, 151 insertions, 44 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 04c7dae30b..d7b300321a 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -152,6 +153,27 @@ Status PruneGraph(GrapplerItem* item) { return Status::OK(); } +// Replace any unknown dimensions in a shape with +// cfg.placeholder_unknown_output_shape_dim if it is no less than 0. +Status ReplaceUnknownShapeDim(const ItemConfig& cfg, + const TensorShapeProto& shape_pb_in, + TensorShapeProto* shape_pb_out, + TensorShape* shape_out) { + std::vector<int32> dims; + for (const auto& dim_proto : shape_pb_in.dim()) { + if (cfg.placeholder_unknown_output_shape_dim >= 0 && + dim_proto.size() == -1) { + dims.push_back(cfg.placeholder_unknown_output_shape_dim); + shape_pb_out->add_dim()->set_size( + cfg.placeholder_unknown_output_shape_dim); + } else { + dims.push_back(std::max<int32>(1, dim_proto.size())); + shape_pb_out->add_dim()->set_size(dim_proto.size()); + } + } + return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out); +} + } // namespace // static @@ -181,48 +203,92 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( } } - // Detect feed and fetch nodes from signature defs. + // Detect feed and fetch nodes from signature defs. Signatures may share same + // inputs or outputs. + std::unordered_set<string> signature_feed_nodes; + std::unordered_set<string> signature_fetch_nodes; for (const auto& name_and_signature : meta_graph.signature_def()) { for (const auto& name_and_input : name_and_signature.second.inputs()) { const TensorInfo& input = name_and_input.second; if (input.has_coo_sparse()) { // Define the shapes following the comment of CooSparse. - PartialTensorShape partial_shape_1d({-1}); - PartialTensorShape partial_shape_2d({-1, -1}); - TensorShape shape_1d; - TensorShape shape_2d; - if (!partial_shape_1d.AsTensorShape(&shape_1d) || - !partial_shape_2d.AsTensorShape(&shape_2d)) { - LOG(ERROR) << "Internal error when constructing tensor shapes."; - return nullptr; + // TODO(yuefengz): we probably want to use different dim values for the + // three tensors of a SparseTensor. + int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim); + TensorShape shape_1d({dim}); + TensorShape shape_2d({dim, dim}); + + if (gtl::InsertIfNotPresent( + &signature_feed_nodes, + NodeName(input.coo_sparse().values_tensor_name()))) { + Tensor value_tensor(input.dtype(), shape_1d); + InitializeTensor(input.dtype(), &value_tensor); + new_item->feed.emplace_back( + NodeName(input.coo_sparse().values_tensor_name()), value_tensor); + } + if (gtl::InsertIfNotPresent( + &signature_feed_nodes, + NodeName(input.coo_sparse().indices_tensor_name()))) { + Tensor indices_tensor(DT_INT64, shape_2d); + InitializeTensor(input.dtype(), &indices_tensor); + new_item->feed.emplace_back( + NodeName(input.coo_sparse().indices_tensor_name()), + indices_tensor); + } + if (gtl::InsertIfNotPresent( + &signature_feed_nodes, + NodeName(input.coo_sparse().dense_shape_tensor_name()))) { + Tensor dense_shape_tensor(DT_INT64, shape_1d); + InitializeTensor(input.dtype(), &dense_shape_tensor); + new_item->feed.emplace_back( + NodeName(input.coo_sparse().dense_shape_tensor_name()), + dense_shape_tensor); } - - new_item->feed.emplace_back( - NodeName(input.coo_sparse().values_tensor_name()), - Tensor(input.dtype(), shape_1d)); - new_item->feed.emplace_back( - NodeName(input.coo_sparse().indices_tensor_name()), - Tensor(DT_INT64, shape_2d)); - new_item->feed.emplace_back( - NodeName(input.coo_sparse().dense_shape_tensor_name()), - Tensor(DT_INT64, shape_1d)); } else { - new_item->feed.emplace_back( - NodeName(input.name()), - Tensor(input.dtype(), input.tensor_shape())); + if (gtl::InsertIfNotPresent(&signature_feed_nodes, + NodeName(input.name()))) { + TensorShape shape; + TensorShapeProto shape_proto; + Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(), + &shape_proto, &shape); + if (!s.ok()) { + LOG(ERROR) << "Invalid shape for signature input " << input.name() + << ": " << s << ", skipping this input"; + return nullptr; + } + + Tensor fake_input(input.dtype(), shape); + InitializeTensor(input.dtype(), &fake_input); + new_item->feed.emplace_back(NodeName(input.name()), fake_input); + } } } for (const auto& name_and_output : name_and_signature.second.outputs()) { const TensorInfo& output = name_and_output.second; if (output.has_coo_sparse()) { - new_item->fetch.push_back( - NodeName(output.coo_sparse().values_tensor_name())); - new_item->fetch.push_back( - NodeName(output.coo_sparse().indices_tensor_name())); - new_item->fetch.push_back( - NodeName(output.coo_sparse().dense_shape_tensor_name())); + if (gtl::InsertIfNotPresent( + &signature_fetch_nodes, + NodeName(output.coo_sparse().values_tensor_name()))) { + new_item->fetch.push_back( + NodeName(output.coo_sparse().values_tensor_name())); + } + if (gtl::InsertIfNotPresent( + &signature_fetch_nodes, + NodeName(output.coo_sparse().indices_tensor_name()))) { + new_item->fetch.push_back( + NodeName(output.coo_sparse().indices_tensor_name())); + } + if (gtl::InsertIfNotPresent( + &signature_fetch_nodes, + NodeName(output.coo_sparse().dense_shape_tensor_name()))) { + new_item->fetch.push_back( + NodeName(output.coo_sparse().dense_shape_tensor_name())); + } } else { - new_item->fetch.push_back(NodeName(output.name())); + if (gtl::InsertIfNotPresent(&signature_fetch_nodes, + NodeName(output.name()))) { + new_item->fetch.push_back(NodeName(output.name())); + } } } } @@ -377,20 +443,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( // shape is not empty if the shape is partially defined. TensorShape shape; TensorShapeProto shape_proto; - std::vector<int32> dims; - for (const auto& dim_proto : node.attr().at("shape").shape().dim()) { - if (cfg.placeholder_unknown_output_shape_dim >= 0 && - dim_proto.size() == -1) { - dims.push_back(cfg.placeholder_unknown_output_shape_dim); - shape_proto.add_dim()->set_size( - cfg.placeholder_unknown_output_shape_dim); - } else { - dims.push_back(std::max<int32>(1, dim_proto.size())); - shape_proto.add_dim()->set_size(dim_proto.size()); - } - } - Status make_shape_status = - TensorShapeUtils::MakeShape(dims.data(), dims.size(), &shape); + Status make_shape_status = ReplaceUnknownShapeDim( + cfg, node.attr().at("shape").shape(), &shape_proto, &shape); if (!make_shape_status.ok()) { LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": " << make_shape_status << ", skipping this input"; @@ -430,7 +484,9 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( if (cfg.feed_nodes.empty()) { // No specific feed nodes were given. Assume all placeholders are fed. - new_item->feed.emplace_back(node.name(), fake_input); + if (signature_feed_nodes.count(node.name()) == 0) { + new_item->feed.emplace_back(node.name(), fake_input); + } } else if (cfg.feed_nodes.count(node.name()) > 0) { // If specific feed nodes were given, only update their tensors. auto it = find_if(new_item->feed.begin(), new_item->feed.end(), diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index ada90925a4..29488e4b7e 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -319,10 +319,22 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithSignatureDef) { (*serving_signature.mutable_outputs())["output"] = output; (*meta_graph.mutable_signature_def())["serving"] = serving_signature; + // It should be able to dedup the input and output with same names. + TensorInfo input2, output2; + input.set_name("x"); + input.set_dtype(DT_FLOAT); + output.set_name("z"); + SignatureDef serving_signature2; + (*serving_signature.mutable_inputs())["input2"] = input2; + (*serving_signature.mutable_outputs())["output2"] = output2; + (*meta_graph.mutable_signature_def())["serving2"] = serving_signature2; + std::unique_ptr<GrapplerItem> item = GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig()); ASSERT_TRUE(item != nullptr); + EXPECT_EQ(item->feed.size(), 1); + EXPECT_EQ(item->fetch.size(), 1); EXPECT_EQ(item->feed[0].first, "x"); EXPECT_EQ(item->fetch[0], "z"); } @@ -354,6 +366,45 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithIncompleteSignatureDef) { ASSERT_TRUE(item == nullptr); } +TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto shape_1d = PartialTensorShape({-1}); + auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape(shape_1d)); + auto y = ops::Const(s.WithOpName("y"), static_cast<float>(1.0)); + auto z = ops::Add(s.WithOpName("z"), x, y); + + MetaGraphDef meta_graph; + TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def())); + + TensorInfo input, output; + input.set_name("x"); + input.set_dtype(DT_FLOAT); + shape_1d.AsProto(input.mutable_tensor_shape()); + output.set_name("z"); + + SignatureDef serving_signature; + (*serving_signature.mutable_inputs())["input"] = input; + (*serving_signature.mutable_outputs())["output"] = output; + (*meta_graph.mutable_signature_def())["serving"] = serving_signature; + + ItemConfig cfg; + cfg.placeholder_unknown_output_shape_dim = 64; + std::unique_ptr<GrapplerItem> item1 = + GrapplerItemFromMetaGraphDef("0", meta_graph, cfg); + ASSERT_TRUE(item1 != nullptr); + + ASSERT_EQ(item1->feed.size(), 1); + EXPECT_EQ(item1->feed[0].second.NumElements(), 64); + + std::unique_ptr<GrapplerItem> item2 = + GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig()); + ASSERT_TRUE(item2 != nullptr); + + ASSERT_EQ(item2->feed.size(), 1); + EXPECT_EQ(item2->feed[0].second.NumElements(), 1); +} + } // namespace } // namespace grappler } // namespace tensorflow |