aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-04-01 10:08:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-01 11:30:52 -0700
commit4fe6860e3a5f33eb06c5637b38dc1ab3df78acb3 (patch)
treeaf2c324d5b923be152306888ae386fe7e17ccd75
parent01fcc46c16d9965b902e4c902a5a8d827a1a5537 (diff)
Don't fold nodes that have no outgoing edges.
Change: 151919765
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc14
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 49891e2a78..7cddedef2e 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -130,7 +130,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (!status.ok()) {
return false;
}
-
if (op_def->is_stateful()) {
return false;
}
@@ -144,6 +143,15 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return false;
}
+ // No need to (and don't) fold nodes that have no outgoing edges. Such nodes
+ // could be introduced by an earlier constant folding pass and are preserved
+ // in case users want to fetch their values; re-processing them would
+ // lead to an error of adding a duplicated node to graph.
+ auto outputs = node_map_->GetOutputs(node.name());
+ if (outputs.empty()) {
+ return false;
+ }
+
for (const auto& input : node.input()) {
bool is_const = IsConst(*node_map_->GetNode(input));
if (!is_const) {
@@ -224,8 +232,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
Status(error::INVALID_ARGUMENT, "Expected at least one output.");
}
for (int i = 0; i < output_tensors.size(); i++) {
- string node_name = strings::StrCat(
- AddPrefixToNodeName(node.name(), kConstantFoldingConst));
+ string node_name = AddPrefixToNodeName(node.name(), kConstantFoldingConst);
if (output_tensors.size() > 1) {
node_name = strings::StrCat(node_name, "-", i);
}
@@ -299,6 +306,7 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
nodes_to_preserve_.insert(NodeName(node));
}
device_.reset(new DeviceSimple());
+ *output = GraphDef();
TF_RETURN_IF_ERROR(FoldGraph(output));
LOG(INFO) << "Optimized graph size: " << output->node_size();
return Status::OK();