aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-03 13:00:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-03 13:40:29 -0700
commitceda30408f66a7eea86dc359164deb662d5a32d0 (patch)
treea20d71c9d126dca85b7e1588d8f661c13f3a1b6d
parent775d1c03c1772c0c2e10e5884af8d9363cfdf314 (diff)
Enable unary chain hoisting optimization for concat/split/splitv by default.
PiperOrigin-RevId: 195297330
-rw-r--r--tensorflow/core/grappler/op_types.cc38
-rw-r--r--tensorflow/core/grappler/op_types.h4
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc18
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc16
5 files changed, 51 insertions, 27 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 7c936dfca1..c48dc00941 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -476,28 +476,40 @@ bool IsInvolution(const NodeDef& node) {
return involution_ops->count(node.op()) > 0;
}
+bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
+ if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
+ return true;
+ }
+ static const std::unordered_set<string>*
+ value_and_order_and_shape_preserving_ops =
+ CHECK_NOTNULL((new const std::unordered_set<string>{
+ "CheckNumerics",
+ "DebugGradientIdentity",
+ "DeepCopy"
+ "Enter",
+ "Exit",
+ "Identity",
+ "IdentityN",
+ "PreventGradient",
+ "Print",
+ "Snapshot",
+ "StopGradient",
+ }));
+ return value_and_order_and_shape_preserving_ops->count(node.op()) > 0;
+}
+
bool IsValueAndOrderPreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
static const std::unordered_set<string>* value_and_order_preserving_ops =
CHECK_NOTNULL((new const std::unordered_set<string>{
- "CheckNumerics",
- "DebugGradientIdentity",
- "DeepCopy"
- "Enter",
- "Exit",
"ExpandDims",
- "Identity",
- "IdentityN",
- "PreventGradient",
- "Print",
- "Reshape",
"Snapshot",
"Squeeze",
- "StopGradient",
}));
- return value_and_order_preserving_ops->count(node.op()) > 0;
+ return value_and_order_preserving_ops->count(node.op()) > 0 ||
+ IsValueAndOrderAndShapePreserving(node);
}
bool IsValuePreserving(const NodeDef& node) {
@@ -564,7 +576,7 @@ bool IsUnaryElementWise(const NodeDef& node) {
"Tanh",
}));
return element_wise_ops->count(node.op()) > 0 ||
- (!IsIdentityN(node) && IsValueAndOrderPreserving(node));
+ (!IsIdentityN(node) && IsValueAndOrderAndShapePreserving(node));
}
bool HasOpDef(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 7a1b438768..e33dd21538 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -174,6 +174,10 @@ bool ModifiesInputsInPlace(const NodeDef& node);
// own inverse such that f(f(x)) == x.
bool IsInvolution(const NodeDef& node);
+// Returns true if the op preserves the order and value of elements
+// and shape of its first input tensor.
+bool IsValueAndOrderAndShapePreserving(const NodeDef& node);
+
// Returns true if the op preserves the order and value of elements in its
// first input tensor and possible changes its shape.
bool IsValueAndOrderPreserving(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d6510ba681..2a5654f752 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1400,6 +1400,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
return n > 1;
} else if (IsSplit(*node) || IsSplitV(*node)) {
const int num_split = node->attr().at("num_split").i();
+ if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
+ // TODO(rmlarsen): Remove this constraint when we have optimizations
+ // in place for merging slices into splits.
+ return false;
+ }
return num_split > 1 && !IsAlreadyOptimized(*node);
}
return false;
@@ -1458,13 +1463,13 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
if (tails.empty()) {
return Status::OK();
}
- AddControlInputs(ctrl_inputs, root_node);
AddToOptimizationQueue(root_node);
optimized_nodes_.insert(root_node->name());
if (node_is_concat_) {
+ AddControlInputs(ctrl_inputs, root_node);
return HoistChainForConcat(prefix_length, tails, root_node);
} else {
- return HoistChainForSplit(prefix_length, tails, root_node);
+ return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
}
}
@@ -1542,9 +1547,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
IsInPreserveSet(*op)) {
return false;
}
- if (node_is_concat_ &&
- ctx().node_map->GetOutputs(op->name()).size() > 1) {
- // TODO(rmlarsen): Allow and hoist outgoing control edges.
+ if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
+ // TODO(rmlarsen): Allow outgoing control edges.
return false;
}
}
@@ -1612,6 +1616,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
}
Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
+ std::set<string>* ctrl_inputs,
NodeDef* split_node) {
// Create a new chain before the split node to process the input tensor.
const string& split_name = split_node->name();
@@ -1646,6 +1651,9 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
cur_copy->add_input(orig_input);
ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
cur_copy->name());
+ // Make sure all the control inputs are satisfied before running the first
+ // node in the new chain.
+ AddControlInputs(ctrl_inputs, cur_copy);
// Connect all consumers of the tail nodes directly to the
// output port of Split from which the chain started.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 3b297ec0aa..6309dc1a33 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -65,7 +65,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
bool remove_negation = true;
- bool hoist_cwise_unary_chains = false;
+ bool hoist_cwise_unary_chains = true;
bool convert_sqrt_div_to_rsqrt_mul = false;
// Choose which arithmetic optimizer stages will be enabled for a given
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index f903f53a35..d32743f3f2 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2320,16 +2320,16 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
EXPECT_NE(node.name(), "cos_exp_b2");
if (node.name() == "split1") {
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("axis", node.input(0));
EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1));
- EXPECT_EQ("^ctrl1", node.input(2));
found++;
}
if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
EXPECT_EQ("Sin", node.op());
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
found++;
}
if (node.name() == "id_a") {
@@ -2349,8 +2349,11 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
}
if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
EXPECT_EQ("Exp", node.op());
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(4, node.input_size());
EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
+ EXPECT_EQ("^ctrl2", node.input(2));
+ EXPECT_EQ("^ctrl3", node.input(3));
found++;
}
if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
@@ -2360,13 +2363,10 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
found++;
}
if (node.name() == "split2") {
- EXPECT_EQ(6, node.input_size());
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0));
EXPECT_EQ("size_splits2", node.input(1));
EXPECT_EQ("axis", node.input(2));
- EXPECT_EQ("^ctrl1", node.input(3));
- EXPECT_EQ("^ctrl2", node.input(4));
- EXPECT_EQ("^ctrl3", node.input(5));
found++;
}
if (node.name() == "id_a2") {