diff options
author | Yao Zhang <yaozhang@google.com> | 2017-04-01 10:08:08 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-01 11:30:52 -0700 |
commit | 4fe6860e3a5f33eb06c5637b38dc1ab3df78acb3 (patch) | |
tree | af2c324d5b923be152306888ae386fe7e17ccd75 | |
parent | 01fcc46c16d9965b902e4c902a5a8d827a1a5537 (diff) |
Don't fold nodes that have no outgoing edges.
Change: 151919765
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 49891e2a78..7cddedef2e 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -130,7 +130,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (!status.ok()) { return false; } - if (op_def->is_stateful()) { return false; } @@ -144,6 +143,15 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } + // No need to (and don't) fold nodes that have no outgoing edges. Such nodes + // could be introduced by an earlier constant folding pass and are preserved + // in case users want to fetch their values; re-processing them would + // lead to an error of adding a duplicated node to graph. + auto outputs = node_map_->GetOutputs(node.name()); + if (outputs.empty()) { + return false; + } + for (const auto& input : node.input()) { bool is_const = IsConst(*node_map_->GetNode(input)); if (!is_const) { @@ -224,8 +232,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, Status(error::INVALID_ARGUMENT, "Expected at least one output."); } for (int i = 0; i < output_tensors.size(); i++) { - string node_name = strings::StrCat( - AddPrefixToNodeName(node.name(), kConstantFoldingConst)); + string node_name = AddPrefixToNodeName(node.name(), kConstantFoldingConst); if (output_tensors.size() > 1) { node_name = strings::StrCat(node_name, "-", i); } @@ -299,6 +306,7 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, nodes_to_preserve_.insert(NodeName(node)); } device_.reset(new DeviceSimple()); + *output = GraphDef(); TF_RETURN_IF_ERROR(FoldGraph(output)); LOG(INFO) << "Optimized graph size: " << output->node_size(); return Status::OK(); |