aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-04 11:18:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 11:47:06 -0700
commit965e3b0ca01ed7cc951131454b38ab638ff44fbf (patch)
tree7a7dad14dc6a68b1cce6b33c3d37f569a79f63b4
parent5d183ab7fc7b82f1dea0b9fa9c6412c39ade15a1 (diff)
Extend hoisting monotonic functions out of min/max reductions to all monotonic unary functions.
Add the ability to flip Max <-> Min if the function is non-increasing, e.g. Max(Neg(x)) => Neg(Min(x)). PiperOrigin-RevId: 211490436
-rw-r--r--tensorflow/core/grappler/op_types.cc37
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc42
4 files changed, 80 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 653b088b1d..e78239bd43 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,16 +135,37 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
-bool IsElementWiseMonotonic(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_monotonic_ops =
+// Returns true if node represents a unary elementwise function that is
+// monotonic. If *is_non_decreasing is true, the function is non-decreasing,
+// e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
+// e.g. inv.
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
+ static const std::unordered_set<string>* monotonic_non_decreasing_ops =
CHECK_NOTNULL((new std::unordered_set<string>{
- "Relu",
- "Relu6",
- "Sigmoid",
- "Sqrt",
- "Tanh",
+ "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
+ "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
+ "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
+ }));
+ static const std::unordered_set<string>* monotonic_non_increasing_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Inv",
+ "Reciprocal",
+ "Erfc",
+ "Rsqrt",
+ "Neg",
}));
- return element_wise_monotonic_ops->count(node.op()) > 0;
+ if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = true;
+ }
+ return true;
+ } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = false;
+ }
+ return true;
+ }
+ return false;
}
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 94439265c9..25ab6b65ac 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,7 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
-bool IsElementWiseMonotonic(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fed88d536..65947ddce5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2706,8 +2706,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
// 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
+ bool is_non_decreasing = false;
if (!IsInPreserveSet(*inner_function) &&
- IsElementWiseMonotonic(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
@@ -2719,7 +2720,12 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
UpdateConsumers(reduction_node, inner_function->name());
ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
reduction_node->name());
-
+ if (!is_non_decreasing) {
+ // Flip Min<->Max if the function is non-increasing, e.g.
+ // Max(Neg(x)) = Neg(Min(x)).
+ const string opposite = IsMax(*reduction_node) ? "Min" : "Max";
+ reduction_node->set_op(opposite);
+ }
AddToOptimizationQueue(reduction_node);
AddToOptimizationQueue(inner_function);
AddToOptimizationQueue(inner_input);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index bfccc0affd..39517edc06 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -3248,6 +3248,48 @@ TEST_F(ArithmeticOptimizerTest,
VerifyGraphsMatch(item.graph, output, __LINE__);
}
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output neg = ops::Neg(s.WithOpName("neg"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ // Check if the inputs are switched
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "neg") {
+ EXPECT_EQ("Neg", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("reduce_max", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "reduce_max") {
+ EXPECT_EQ("Min", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();