aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-11-29 17:55:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-29 17:59:09 -0800
commitb97585f5d2157b1e0273a4b20a568635fb58ad57 (patch)
treec62d59c616fc9eb42127a16d842631971fa0ff42
parentcb4ef362e4a18b3c42a2c90bdad8754d5ead4caf (diff)
Always leverage shapes inference now that it can handle fed nodes
conservatively. PiperOrigin-RevId: 177391746
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc88
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h6
2 files changed, 52 insertions, 42 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 03eaa4a84a..b5172a4833 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -190,6 +190,14 @@ Status ConvertShapeToConstant(const string& op, const DataType& type,
return Status::OK();
}
+bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
+ if (!IsConstant(node)) {
+ return false;
+ }
+ // If the node is fed it's not constant anymore.
+ return feed_nodes_.find(node.name()) == feed_nodes_.end();
+}
+
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
// We may add some nodes to the graph to encode control dependencies: there is
// no need to process these, so only iterate over the nodes of the input
@@ -327,9 +335,9 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
if (shape_node1 == nullptr ||
- (shape_node1->op() != "Shape" && shape_node1->op() != "Const") ||
+ (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
shape_node2 == nullptr ||
- (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) {
+ (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
return Status::OK();
}
int64 min_id = 0;
@@ -409,7 +417,7 @@ Status ConstantFolding::MaterializeReductionIndices(
return Status::OK();
}
const NodeDef* indices = node_map_->GetNode(node->input(1));
- if (!indices || IsConstant(*indices)) {
+ if (!indices || IsReallyConstant(*indices)) {
// The reduction indices are already constant, there's nothing to do.
return Status::OK();
}
@@ -506,24 +514,23 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (node.input().empty()) {
return false;
}
-
// Skips nodes that must be preserved except whitelisted nodes.
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
return false;
}
-
- // Skips ops that don't benefit from folding.
- const string& op = node.op();
- // Skip constants, they're already folded
- if (op == "Const") {
+ // Skip control flow nodes, they can't be folded
+ if (ModifiesFrameInfo(node)) {
return false;
}
- // Skip constrol flow nodes, they can't be folded
- if (op == "Enter" || op == "RefEnter" || op == "Exit" || op == "RefExit" ||
- op == "NextIteration" || op == "RefNextIteration") {
+ // Skip constants, they're already folded
+ if (IsConstant(node)) {
return false;
}
+
+ // Skips ops that don't benefit from folding.
+ const string& op = node.op();
+
if (op.find("Placeholder") == 0) {
return false;
}
@@ -577,7 +584,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (!input_node) {
return false;
}
- bool is_const = IsConstant(*input_node);
+ bool is_const = IsReallyConstant(*input_node);
if (!is_const && !is_merge) {
return false;
}
@@ -703,7 +710,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
break;
}
const NodeDef* input_node = node_map_->GetNode(input);
- if (!IsConstant(*input_node)) {
+ if (!IsReallyConstant(*input_node)) {
return Status(error::INVALID_ARGUMENT,
strings::StrCat("Can't fold ", node.name(), ", its ", input,
" isn't constant"));
@@ -757,7 +764,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
continue;
}
NodeDef* input_node = node_map_->GetNode(input);
- if (!IsConstant(*input_node)) {
+ if (!IsReallyConstant(*input_node)) {
continue;
}
bool valid_input = true;
@@ -999,7 +1006,7 @@ bool ConstantFolding::IsSimplifiableReduction(const NodeDef& node) const {
if (IsReduction(node)) {
CHECK_LE(2, node.input_size());
const NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
- if (IsConstant(*reductions_indices)) {
+ if (IsReallyConstant(*reductions_indices)) {
TensorVector output;
Status s = EvaluateNode(*reductions_indices, TensorVector(), &output);
if (!s.ok()) {
@@ -1023,7 +1030,7 @@ bool ConstantFolding::IsSimplifiableReshape(
}
CHECK_LE(2, node.input_size());
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
- if (!IsConstant(*new_shape)) {
+ if (!IsReallyConstant(*new_shape)) {
return false;
}
TensorVector outputs;
@@ -1074,7 +1081,8 @@ bool ConstantFolding::IsSimplifiableReshape(
}
Status ConstantFolding::SimplifyGraph(GraphDef* output,
- const GraphProperties& properties) {
+ const GraphProperties& properties,
+ bool use_shape_info) {
for (auto& node : *output->mutable_node()) {
if (IsSimplifiableReduction(node)) {
// Replace the reduction node with an identity node, that can be further
@@ -1099,10 +1107,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
*node.add_input() = input;
}
}
- // It's possible to feed a placeholder with a tensor that doesn't have the
- // proper shape, and reshape this tensor later on. Therefore only remove
- // reshapes in graphs that don't have placeholders.
- if (IsSimplifiableReshape(node, properties)) {
+ const bool safe_to_use_shapes =
+ use_shape_info &&
+ (feed_nodes_.empty() || opt_level_ == RewriterConfig::AGGRESSIVE);
+ if (safe_to_use_shapes && IsSimplifiableReshape(node, properties)) {
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
DataType output_type = node.attr().at("T").type();
node.set_op("Identity");
@@ -1141,36 +1149,34 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
}
GraphProperties properties(item);
- const bool has_feed = !item.feed.empty();
- bool needs_shapes = !has_feed || opt_level_ == RewriterConfig::AGGRESSIVE;
- Status s = errors::Unknown(
- "The graph properties are needed but were not initialized");
- if (needs_shapes) {
- s = properties.InferStatically(false);
- }
-
- if (!has_feed && s.ok()) {
- // Only use static shape information when there is no feed in the
- // graph. That's because it's possible to feed a placeholder with a tensor
- // of any shape, which could make the static information inconsistent with
- // the shapes actually fed.
+ // It's possible to feed a placeholder with a tensor of any shape: make sure
+ // that the shape inference deals with this conservatively unless we're in
+ // aggressive mode.
+ const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
+ Status s = properties.InferStatically(assume_valid_feeds);
+ const bool can_use_shape_info = s.ok();
+
+ if (can_use_shape_info) {
TF_RETURN_IF_ERROR(MaterializeShapes(properties));
- }
- if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) {
- TF_RETURN_IF_ERROR(MaterializeConstants(properties));
+
+ if (opt_level_ == RewriterConfig::AGGRESSIVE) {
+ TF_RETURN_IF_ERROR(MaterializeConstants(properties));
+ }
}
TF_RETURN_IF_ERROR(FoldGraph(output));
- if (!has_feed && s.ok()) {
- TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
- }
+ TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+
return Status::OK();
}
Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
nodes_to_preserve_ = item.NodesToPreserve();
+ for (const auto& feed : item.feed) {
+ feed_nodes_.insert(NodeName(feed.first));
+ }
if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 7c5db2a70f..8af5b5fbe6 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -51,6 +51,8 @@ class ConstantFolding : public GraphOptimizer {
const GraphDef& optimize_output, double result) override;
private:
+ bool IsReallyConstant(const NodeDef& node) const;
+
Status MaterializeShapes(const GraphProperties& properties);
Status MaterializeBroadcastGradientArgs(const NodeDef& node,
@@ -75,7 +77,8 @@ class ConstantFolding : public GraphOptimizer {
bool IsSimplifiableReduction(const NodeDef& node) const;
bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
- Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);
+ Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+ bool use_shape_info);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);
@@ -90,6 +93,7 @@ class ConstantFolding : public GraphOptimizer {
std::unique_ptr<NodeMap> node_map_;
std::unordered_set<string> nodes_to_preserve_;
std::unordered_set<string> nodes_whitelist_;
+ std::unordered_set<string> feed_nodes_;
bool has_fetch_;
};