diff options
author | 2017-11-29 17:55:53 -0800 | |
---|---|---|
committer | 2017-11-29 17:59:09 -0800 | |
commit | b97585f5d2157b1e0273a4b20a568635fb58ad57 (patch) | |
tree | c62d59c616fc9eb42127a16d842631971fa0ff42 | |
parent | cb4ef362e4a18b3c42a2c90bdad8754d5ead4caf (diff) |
Always leverage shapes inference now that it can handle fed nodes
conservatively.
PiperOrigin-RevId: 177391746
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 88 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.h | 6 |
2 files changed, 52 insertions, 42 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 03eaa4a84a..b5172a4833 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -190,6 +190,14 @@ Status ConvertShapeToConstant(const string& op, const DataType& type, return Status::OK(); } +bool ConstantFolding::IsReallyConstant(const NodeDef& node) const { + if (!IsConstant(node)) { + return false; + } + // If the node is fed it's not constant anymore. + return feed_nodes_.find(node.name()) == feed_nodes_.end(); +} + Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // We may add some nodes to the graph to encode control dependencies: there is // no need to process these, so only iterate over the nodes of the input @@ -327,9 +335,9 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( const NodeDef* shape_node1 = node_map_->GetNode(node.input(0)); const NodeDef* shape_node2 = node_map_->GetNode(node.input(1)); if (shape_node1 == nullptr || - (shape_node1->op() != "Shape" && shape_node1->op() != "Const") || + (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) || shape_node2 == nullptr || - (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) { + (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) { return Status::OK(); } int64 min_id = 0; @@ -409,7 +417,7 @@ Status ConstantFolding::MaterializeReductionIndices( return Status::OK(); } const NodeDef* indices = node_map_->GetNode(node->input(1)); - if (!indices || IsConstant(*indices)) { + if (!indices || IsReallyConstant(*indices)) { // The reduction indices are already constant, there's nothing to do. return Status::OK(); } @@ -506,24 +514,23 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (node.input().empty()) { return false; } - // Skips nodes that must be preserved except whitelisted nodes. if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() && nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) { return false; } - - // Skips ops that don't benefit from folding. - const string& op = node.op(); - // Skip constants, they're already folded - if (op == "Const") { + // Skip control flow nodes, they can't be folded + if (ModifiesFrameInfo(node)) { return false; } - // Skip constrol flow nodes, they can't be folded - if (op == "Enter" || op == "RefEnter" || op == "Exit" || op == "RefExit" || - op == "NextIteration" || op == "RefNextIteration") { + // Skip constants, they're already folded + if (IsConstant(node)) { return false; } + + // Skips ops that don't benefit from folding. + const string& op = node.op(); + if (op.find("Placeholder") == 0) { return false; } @@ -577,7 +584,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (!input_node) { return false; } - bool is_const = IsConstant(*input_node); + bool is_const = IsReallyConstant(*input_node); if (!is_const && !is_merge) { return false; } @@ -703,7 +710,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, break; } const NodeDef* input_node = node_map_->GetNode(input); - if (!IsConstant(*input_node)) { + if (!IsReallyConstant(*input_node)) { return Status(error::INVALID_ARGUMENT, strings::StrCat("Can't fold ", node.name(), ", its ", input, " isn't constant")); @@ -757,7 +764,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) { continue; } NodeDef* input_node = node_map_->GetNode(input); - if (!IsConstant(*input_node)) { + if (!IsReallyConstant(*input_node)) { continue; } bool valid_input = true; @@ -999,7 +1006,7 @@ bool ConstantFolding::IsSimplifiableReduction(const NodeDef& node) const { if (IsReduction(node)) { CHECK_LE(2, node.input_size()); const NodeDef* reductions_indices = node_map_->GetNode(node.input(1)); - if (IsConstant(*reductions_indices)) { + if (IsReallyConstant(*reductions_indices)) { TensorVector output; Status s = EvaluateNode(*reductions_indices, TensorVector(), &output); if (!s.ok()) { @@ -1023,7 +1030,7 @@ bool ConstantFolding::IsSimplifiableReshape( } CHECK_LE(2, node.input_size()); const NodeDef* new_shape = node_map_->GetNode(node.input(1)); - if (!IsConstant(*new_shape)) { + if (!IsReallyConstant(*new_shape)) { return false; } TensorVector outputs; @@ -1074,7 +1081,8 @@ bool ConstantFolding::IsSimplifiableReshape( } Status ConstantFolding::SimplifyGraph(GraphDef* output, - const GraphProperties& properties) { + const GraphProperties& properties, + bool use_shape_info) { for (auto& node : *output->mutable_node()) { if (IsSimplifiableReduction(node)) { // Replace the reduction node with an identity node, that can be further @@ -1099,10 +1107,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, *node.add_input() = input; } } - // It's possible to feed a placeholder with a tensor that doesn't have the - // proper shape, and reshape this tensor later on. Therefore only remove - // reshapes in graphs that don't have placeholders. - if (IsSimplifiableReshape(node, properties)) { + const bool safe_to_use_shapes = + use_shape_info && + (feed_nodes_.empty() || opt_level_ == RewriterConfig::AGGRESSIVE); + if (safe_to_use_shapes && IsSimplifiableReshape(node, properties)) { const NodeDef* new_shape = node_map_->GetNode(node.input(1)); DataType output_type = node.attr().at("T").type(); node.set_op("Identity"); @@ -1141,36 +1149,34 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, } GraphProperties properties(item); - const bool has_feed = !item.feed.empty(); - bool needs_shapes = !has_feed || opt_level_ == RewriterConfig::AGGRESSIVE; - Status s = errors::Unknown( - "The graph properties are needed but were not initialized"); - if (needs_shapes) { - s = properties.InferStatically(false); - } - - if (!has_feed && s.ok()) { - // Only use static shape information when there is no feed in the - // graph. That's because it's possible to feed a placeholder with a tensor - // of any shape, which could make the static information inconsistent with - // the shapes actually fed. + // It's possible to feed a placeholder with a tensor of any shape: make sure + // that the shape inference deals with this conservatively unless we're in + // aggressive mode. + const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE; + Status s = properties.InferStatically(assume_valid_feeds); + const bool can_use_shape_info = s.ok(); + + if (can_use_shape_info) { TF_RETURN_IF_ERROR(MaterializeShapes(properties)); - } - if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) { - TF_RETURN_IF_ERROR(MaterializeConstants(properties)); + + if (opt_level_ == RewriterConfig::AGGRESSIVE) { + TF_RETURN_IF_ERROR(MaterializeConstants(properties)); + } } TF_RETURN_IF_ERROR(FoldGraph(output)); - if (!has_feed && s.ok()) { - TF_RETURN_IF_ERROR(SimplifyGraph(output, properties)); - } + TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info)); + return Status::OK(); } Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { nodes_to_preserve_ = item.NodesToPreserve(); + for (const auto& feed : item.feed) { + feed_nodes_.insert(NodeName(feed.first)); + } if (cpu_device_ == nullptr) { owned_device_.reset(new DeviceSimple()); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 7c5db2a70f..8af5b5fbe6 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -51,6 +51,8 @@ class ConstantFolding : public GraphOptimizer { const GraphDef& optimize_output, double result) override; private: + bool IsReallyConstant(const NodeDef& node) const; + Status MaterializeShapes(const GraphProperties& properties); Status MaterializeBroadcastGradientArgs(const NodeDef& node, @@ -75,7 +77,8 @@ class ConstantFolding : public GraphOptimizer { bool IsSimplifiableReduction(const NodeDef& node) const; bool IsSimplifiableReshape(const NodeDef& node, const GraphProperties& properties) const; - Status SimplifyGraph(GraphDef* output, const GraphProperties& properties); + Status SimplifyGraph(GraphDef* output, const GraphProperties& properties, + bool use_shape_info); Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item, GraphDef* output); @@ -90,6 +93,7 @@ class ConstantFolding : public GraphOptimizer { std::unique_ptr<NodeMap> node_map_; std::unordered_set<string> nodes_to_preserve_; std::unordered_set<string> nodes_whitelist_; + std::unordered_set<string> feed_nodes_; bool has_fetch_; }; |