aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-03 13:00:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-03 13:40:29 -0700
commitceda30408f66a7eea86dc359164deb662d5a32d0 (patch)
treea20d71c9d126dca85b7e1588d8f661c13f3a1b6d /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent775d1c03c1772c0c2e10e5884af8d9363cfdf314 (diff)
Enable unary chain hoisting optimization for concat/split/splitv by default.
PiperOrigin-RevId: 195297330
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc18
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d6510ba681..2a5654f752 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1400,6 +1400,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
return n > 1;
} else if (IsSplit(*node) || IsSplitV(*node)) {
const int num_split = node->attr().at("num_split").i();
+ if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
+ // TODO(rmlarsen): Remove this constraint when we have optimizations
+ // in place for merging slices into splits.
+ return false;
+ }
return num_split > 1 && !IsAlreadyOptimized(*node);
}
return false;
@@ -1458,13 +1463,13 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
if (tails.empty()) {
return Status::OK();
}
- AddControlInputs(ctrl_inputs, root_node);
AddToOptimizationQueue(root_node);
optimized_nodes_.insert(root_node->name());
if (node_is_concat_) {
+ AddControlInputs(ctrl_inputs, root_node);
return HoistChainForConcat(prefix_length, tails, root_node);
} else {
- return HoistChainForSplit(prefix_length, tails, root_node);
+ return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
}
}
@@ -1542,9 +1547,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
IsInPreserveSet(*op)) {
return false;
}
- if (node_is_concat_ &&
- ctx().node_map->GetOutputs(op->name()).size() > 1) {
- // TODO(rmlarsen): Allow and hoist outgoing control edges.
+ if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
+ // TODO(rmlarsen): Allow outgoing control edges.
return false;
}
}
@@ -1612,6 +1616,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
}
Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
+ std::set<string>* ctrl_inputs,
NodeDef* split_node) {
// Create a new chain before the split node to process the input tensor.
const string& split_name = split_node->name();
@@ -1646,6 +1651,9 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
cur_copy->add_input(orig_input);
ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
cur_copy->name());
+ // Make sure all the control inputs are satisfied before running the first
+ // node in the new chain.
+ AddControlInputs(ctrl_inputs, cur_copy);
// Connect all consumers of the tail nodes directly to the
// output port of Split from which the chain started.