aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/grappler_item_builder.cc
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-03-13 11:27:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-13 11:32:37 -0700
commit7d02968ce04d9576f152e6cfd0c88da096e862a4 (patch)
treecad0502ef74c0e7fcba81fb579eed3d3e3d71e56 /tensorflow/core/grappler/grappler_item_builder.cc
parentea9e65c94ad71ca86d2be91c4109c62269b42cf8 (diff)
Replace the unknown dimension of signature input when building grappler items.
Fix the bug where the same feed nodes or fetch nodes would be added more than once. PiperOrigin-RevId: 188902101
Diffstat (limited to 'tensorflow/core/grappler/grappler_item_builder.cc')
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc144
1 file changed, 100 insertions, 44 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 04c7dae30b..d7b300321a 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -152,6 +153,27 @@ Status PruneGraph(GrapplerItem* item) {
return Status::OK();
}
+// Replace any unknown dimensions in a shape with
+// cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
+Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
+ const TensorShapeProto& shape_pb_in,
+ TensorShapeProto* shape_pb_out,
+ TensorShape* shape_out) {
+ std::vector<int32> dims;
+ for (const auto& dim_proto : shape_pb_in.dim()) {
+ if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
+ dim_proto.size() == -1) {
+ dims.push_back(cfg.placeholder_unknown_output_shape_dim);
+ shape_pb_out->add_dim()->set_size(
+ cfg.placeholder_unknown_output_shape_dim);
+ } else {
+ dims.push_back(std::max<int32>(1, dim_proto.size()));
+ shape_pb_out->add_dim()->set_size(dim_proto.size());
+ }
+ }
+ return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
+}
+
} // namespace
// static
@@ -181,48 +203,92 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}
}
- // Detect feed and fetch nodes from signature defs.
+ // Detect feed and fetch nodes from signature defs. Signatures may share same
+ // inputs or outputs.
+ std::unordered_set<string> signature_feed_nodes;
+ std::unordered_set<string> signature_fetch_nodes;
for (const auto& name_and_signature : meta_graph.signature_def()) {
for (const auto& name_and_input : name_and_signature.second.inputs()) {
const TensorInfo& input = name_and_input.second;
if (input.has_coo_sparse()) {
// Define the shapes following the comment of CooSparse.
- PartialTensorShape partial_shape_1d({-1});
- PartialTensorShape partial_shape_2d({-1, -1});
- TensorShape shape_1d;
- TensorShape shape_2d;
- if (!partial_shape_1d.AsTensorShape(&shape_1d) ||
- !partial_shape_2d.AsTensorShape(&shape_2d)) {
- LOG(ERROR) << "Internal error when constructing tensor shapes.";
- return nullptr;
+ // TODO(yuefengz): we probably want to use different dim values for the
+ // three tensors of a SparseTensor.
+ int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
+ TensorShape shape_1d({dim});
+ TensorShape shape_2d({dim, dim});
+
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().values_tensor_name()))) {
+ Tensor value_tensor(input.dtype(), shape_1d);
+ InitializeTensor(input.dtype(), &value_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().indices_tensor_name()))) {
+ Tensor indices_tensor(DT_INT64, shape_2d);
+ InitializeTensor(input.dtype(), &indices_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().indices_tensor_name()),
+ indices_tensor);
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_feed_nodes,
+ NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
+ Tensor dense_shape_tensor(DT_INT64, shape_1d);
+ InitializeTensor(input.dtype(), &dense_shape_tensor);
+ new_item->feed.emplace_back(
+ NodeName(input.coo_sparse().dense_shape_tensor_name()),
+ dense_shape_tensor);
}
-
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().values_tensor_name()),
- Tensor(input.dtype(), shape_1d));
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().indices_tensor_name()),
- Tensor(DT_INT64, shape_2d));
- new_item->feed.emplace_back(
- NodeName(input.coo_sparse().dense_shape_tensor_name()),
- Tensor(DT_INT64, shape_1d));
} else {
- new_item->feed.emplace_back(
- NodeName(input.name()),
- Tensor(input.dtype(), input.tensor_shape()));
+ if (gtl::InsertIfNotPresent(&signature_feed_nodes,
+ NodeName(input.name()))) {
+ TensorShape shape;
+ TensorShapeProto shape_proto;
+ Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
+ &shape_proto, &shape);
+ if (!s.ok()) {
+ LOG(ERROR) << "Invalid shape for signature input " << input.name()
+ << ": " << s << ", skipping this input";
+ return nullptr;
+ }
+
+ Tensor fake_input(input.dtype(), shape);
+ InitializeTensor(input.dtype(), &fake_input);
+ new_item->feed.emplace_back(NodeName(input.name()), fake_input);
+ }
}
}
for (const auto& name_and_output : name_and_signature.second.outputs()) {
const TensorInfo& output = name_and_output.second;
if (output.has_coo_sparse()) {
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().values_tensor_name()));
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().indices_tensor_name()));
- new_item->fetch.push_back(
- NodeName(output.coo_sparse().dense_shape_tensor_name()));
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().values_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().values_tensor_name()));
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().indices_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().indices_tensor_name()));
+ }
+ if (gtl::InsertIfNotPresent(
+ &signature_fetch_nodes,
+ NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
+ new_item->fetch.push_back(
+ NodeName(output.coo_sparse().dense_shape_tensor_name()));
+ }
} else {
- new_item->fetch.push_back(NodeName(output.name()));
+ if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
+ NodeName(output.name()))) {
+ new_item->fetch.push_back(NodeName(output.name()));
+ }
}
}
}
@@ -377,20 +443,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
// shape is not empty if the shape is partially defined.
TensorShape shape;
TensorShapeProto shape_proto;
- std::vector<int32> dims;
- for (const auto& dim_proto : node.attr().at("shape").shape().dim()) {
- if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
- dim_proto.size() == -1) {
- dims.push_back(cfg.placeholder_unknown_output_shape_dim);
- shape_proto.add_dim()->set_size(
- cfg.placeholder_unknown_output_shape_dim);
- } else {
- dims.push_back(std::max<int32>(1, dim_proto.size()));
- shape_proto.add_dim()->set_size(dim_proto.size());
- }
- }
- Status make_shape_status =
- TensorShapeUtils::MakeShape(dims.data(), dims.size(), &shape);
+ Status make_shape_status = ReplaceUnknownShapeDim(
+ cfg, node.attr().at("shape").shape(), &shape_proto, &shape);
if (!make_shape_status.ok()) {
LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": "
<< make_shape_status << ", skipping this input";
@@ -430,7 +484,9 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
if (cfg.feed_nodes.empty()) {
// No specific feed nodes were given. Assume all placeholders are fed.
- new_item->feed.emplace_back(node.name(), fake_input);
+ if (signature_feed_nodes.count(node.name()) == 0) {
+ new_item->feed.emplace_back(node.name(), fake_input);
+ }
} else if (cfg.feed_nodes.count(node.name()) > 0) {
// If specific feed nodes were given, only update their tensors.
auto it = find_if(new_item->feed.begin(), new_item->feed.end(),