diff options
author | 2018-05-03 13:00:56 -0700 | |
---|---|---|
committer | 2018-05-03 13:40:29 -0700 | |
commit | ceda30408f66a7eea86dc359164deb662d5a32d0 (patch) | |
tree | a20d71c9d126dca85b7e1588d8f661c13f3a1b6d /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 775d1c03c1772c0c2e10e5884af8d9363cfdf314 (diff) |
Enable unary chain hoisting optimization for concat/split/splitv by default.
PiperOrigin-RevId: 195297330
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d6510ba681..2a5654f752 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1400,6 +1400,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { return n > 1; } else if (IsSplit(*node) || IsSplitV(*node)) { const int num_split = node->attr().at("num_split").i(); + if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) { + // TODO(rmlarsen): Remove this constraint when we have optimizations + // in place for merging slices into splits. + return false; + } return num_split > 1 && !IsAlreadyOptimized(*node); } return false; @@ -1458,13 +1463,13 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { if (tails.empty()) { return Status::OK(); } - AddControlInputs(ctrl_inputs, root_node); AddToOptimizationQueue(root_node); optimized_nodes_.insert(root_node->name()); if (node_is_concat_) { + AddControlInputs(ctrl_inputs, root_node); return HoistChainForConcat(prefix_length, tails, root_node); } else { - return HoistChainForSplit(prefix_length, tails, root_node); + return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node); } } @@ -1542,9 +1547,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { IsInPreserveSet(*op)) { return false; } - if (node_is_concat_ && - ctx().node_map->GetOutputs(op->name()).size() > 1) { - // TODO(rmlarsen): Allow and hoist outgoing control edges. + if (ctx().node_map->GetOutputs(op->name()).size() > 1) { + // TODO(rmlarsen): Allow outgoing control edges. return false; } } @@ -1612,6 +1616,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { } Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails, + std::set<string>* ctrl_inputs, NodeDef* split_node) { // Create a new chain before the split node to process the input tensor. const string& split_name = split_node->name(); @@ -1646,6 +1651,9 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { cur_copy->add_input(orig_input); ctx().node_map->UpdateOutput(NodeName(orig_input), split_name, cur_copy->name()); + // Make sure all the control inputs are satisfied before running the first + // node in the new chain. + AddControlInputs(ctrl_inputs, cur_copy); // Connect all consumers of the tail nodes directly to the // output port of Split from which the chain started. |