aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-03-13 11:27:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-13 11:32:37 -0700
commit7d02968ce04d9576f152e6cfd0c88da096e862a4 (patch)
treecad0502ef74c0e7fcba81fb579eed3d3e3d71e56 /tensorflow/core
parentea9e65c94ad71ca86d2be91c4109c62269b42cf8 (diff)
Replace the unknown dimension of signature input when building grappler items.
Fix the bug where same feed nodes or fetch nodes would be added more than once. PiperOrigin-RevId: 188902101
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc144
-rw-r--r--tensorflow/core/grappler/grappler_item_builder_test.cc51
2 files changed, 151 insertions, 44 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 04c7dae30b..d7b300321a 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -152,6 +153,27 @@ Status PruneGraph(GrapplerItem* item) {
return Status::OK();
}
+// Replace any unknown dimensions in a shape with
+// cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
+Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
+ const TensorShapeProto& shape_pb_in,
+ TensorShapeProto* shape_pb_out,
+ TensorShape* shape_out) {
+ std::vector<int32> dims;
+ for (const auto& dim_proto : shape_pb_in.dim()) {
+ if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
+ dim_proto.size() == -1) {
+ dims.push_back(cfg.placeholder_unknown_output_shape_dim);
+ shape_pb_out->add_dim()->set_size(
+ cfg.placeholder_unknown_output_shape_dim);
+ } else {
+ dims.push_back(std::max<int32>(1, dim_proto.size()));
+ shape_pb_out->add_dim()->set_size(dim_proto.size());
+ }
+ }
+ return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
+}
+
} // namespace
// static
@@ -181,48 +203,92 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}
}
- // Detect feed and fetch nodes from signature defs.
+ // Detect feed and fetch nodes from signature defs. Signatures may share same
+ // inputs or outputs.
+ std::unordered_set<string> signature_feed_nodes;
+ std::unordered_set<string> signature_fetch_nodes;
for (const auto& name_and_signature : meta_graph.signature_def()) {
for (const auto& name_and_input : name_and_signature.second.inputs()) {
const TensorInfo& input = name_and_input.second;
if (input.has_coo_sparse()) {
// Define the shapes following the comment of CooSparse.
- PartialTensorShape partial_shape_1d({-1});
- PartialTensorShape partial_shape_2d({-1, -1});
- TensorShape shape_1d;
- TensorShape shape_2d;
- if (!partial_shape_1d.AsTensorShape(&shape_1d) ||
- !partial_shape_2d.AsTensorShape(&shape_2d)) {
- LOG(ERROR) << "Internal error when constructing tensor shapes.";
- return nullptr;
+ // TODO(yuefengz): we probably want to use different dim values for the
+ // three tensors of a SparseTensor.
+ int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
+ TensorShape shape_1d({dim});
+ TensorShape shape_2d({dim, dim});
+
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().values_tensor_name()))) {
+ Tensor value_tensor(input.dtype(), shape_1d);
+ InitializeTensor(input.dtype(), &value_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().indices_tensor_name()))) {
+ Tensor indices_tensor(DT_INT64, shape_2d);
+ InitializeTensor(input.dtype(), &indices_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().indices_tensor_name()),
+ indices_tensor);
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
+ Tensor dense_shape_tensor(DT_INT64, shape_1d);
+ InitializeTensor(input.dtype(), &dense_shape_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().dense_shape_tensor_name()),
+ dense_shape_tensor);
}
-
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().values_tensor_name()),
- Tensor(input.dtype(), shape_1d));
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().indices_tensor_name()),
- Tensor(DT_INT64, shape_2d));
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().dense_shape_tensor_name()),
- Tensor(DT_INT64, shape_1d));
} else {
- new_item->feed.emplace_back(
- NodeName(input.name()),
- Tensor(input.dtype(), input.tensor_shape()));
+ if (gtl::InsertIfNotPresent(&signature_feed_nodes,
+ NodeName(input.name()))) {
+ TensorShape shape;
+ TensorShapeProto shape_proto;
+ Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
+ &shape_proto, &shape);
+ if (!s.ok()) {
+ LOG(ERROR) << "Invalid shape for signature input " << input.name()
+ << ": " << s << ", skipping this input";
+ return nullptr;
+ }
+
+ Tensor fake_input(input.dtype(), shape);
+ InitializeTensor(input.dtype(), &fake_input);
+ new_item->feed.emplace_back(NodeName(input.name()), fake_input);
+ }
}
}
for (const auto& name_and_output : name_and_signature.second.outputs()) {
const TensorInfo& output = name_and_output.second;
if (output.has_coo_sparse()) {
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().values_tensor_name()));
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().indices_tensor_name()));
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().dense_shape_tensor_name()));
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().values_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().values_tensor_name()));
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().indices_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().indices_tensor_name()));
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().dense_shape_tensor_name()));
+ }
} else {
- new_item->fetch.push_back(NodeName(output.name()));
+ if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
+ NodeName(output.name()))) {
+ new_item->fetch.push_back(NodeName(output.name()));
+ }
}
}
}
@@ -377,20 +443,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
// shape is not empty if the shape is partially defined.
TensorShape shape;
TensorShapeProto shape_proto;
- std::vector<int32> dims;
- for (const auto& dim_proto : node.attr().at("shape").shape().dim()) {
- if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
- dim_proto.size() == -1) {
- dims.push_back(cfg.placeholder_unknown_output_shape_dim);
- shape_proto.add_dim()->set_size(
- cfg.placeholder_unknown_output_shape_dim);
- } else {
- dims.push_back(std::max<int32>(1, dim_proto.size()));
- shape_proto.add_dim()->set_size(dim_proto.size());
- }
- }
- Status make_shape_status =
- TensorShapeUtils::MakeShape(dims.data(), dims.size(), &shape);
+ Status make_shape_status = ReplaceUnknownShapeDim(
+ cfg, node.attr().at("shape").shape(), &shape_proto, &shape);
if (!make_shape_status.ok()) {
LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": "
<< make_shape_status << ", skipping this input";
@@ -430,7 +484,9 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
if (cfg.feed_nodes.empty()) {
// No specific feed nodes were given. Assume all placeholders are fed.
- new_item->feed.emplace_back(node.name(), fake_input);
+ if (signature_feed_nodes.count(node.name()) == 0) {
+ new_item->feed.emplace_back(node.name(), fake_input);
+ }
} else if (cfg.feed_nodes.count(node.name()) > 0) {
// If specific feed nodes were given, only update their tensors.
auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index ada90925a4..29488e4b7e 100644
--- a/tensorflow/core/grappler/grappler_item_builder_test.cc
+++ b/tensorflow/core/grappler/grappler_item_builder_test.cc
@@ -319,10 +319,22 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithSignatureDef) {
(*serving_signature.mutable_outputs())["output"] = output;
(*meta_graph.mutable_signature_def())["serving"] = serving_signature;
+ // It should be able to dedup the input and output with same names.
+ TensorInfo input2, output2;
+ input.set_name("x");
+ input.set_dtype(DT_FLOAT);
+ output.set_name("z");
+ SignatureDef serving_signature2;
+ (*serving_signature.mutable_inputs())["input2"] = input2;
+ (*serving_signature.mutable_outputs())["output2"] = output2;
+ (*meta_graph.mutable_signature_def())["serving2"] = serving_signature2;
+
std::unique_ptr<GrapplerItem> item =
GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig());
ASSERT_TRUE(item != nullptr);
+ EXPECT_EQ(item->feed.size(), 1);
+ EXPECT_EQ(item->fetch.size(), 1);
EXPECT_EQ(item->feed[0].first, "x");
EXPECT_EQ(item->fetch[0], "z");
}
@@ -354,6 +366,45 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithIncompleteSignatureDef) {
ASSERT_TRUE(item == nullptr);
}
+TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto shape_1d = PartialTensorShape({-1});
+ auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape(shape_1d));
+ auto y = ops::Const(s.WithOpName("y"), static_cast<float>(1.0));
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ TensorInfo input, output;
+ input.set_name("x");
+ input.set_dtype(DT_FLOAT);
+ shape_1d.AsProto(input.mutable_tensor_shape());
+ output.set_name("z");
+
+ SignatureDef serving_signature;
+ (*serving_signature.mutable_inputs())["input"] = input;
+ (*serving_signature.mutable_outputs())["output"] = output;
+ (*meta_graph.mutable_signature_def())["serving"] = serving_signature;
+
+ ItemConfig cfg;
+ cfg.placeholder_unknown_output_shape_dim = 64;
+ std::unique_ptr<GrapplerItem> item1 =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
+ ASSERT_TRUE(item1 != nullptr);
+
+ ASSERT_EQ(item1->feed.size(), 1);
+ EXPECT_EQ(item1->feed[0].second.NumElements(), 64);
+
+ std::unique_ptr<GrapplerItem> item2 =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, ItemConfig());
+ ASSERT_TRUE(item2 != nullptr);
+
+ ASSERT_EQ(item2->feed.size(), 1);
+ EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow