aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-21 16:14:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-21 16:18:47 -0700
commit4d03411da6fcc803d9abcef97a59072144e325f9 (patch)
treee336af7175e3ed146f80a38d3275c5026cdc1f39 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parentd67a99406f986021df9a0f8a99ff6eaf801dcb25 (diff)
Add arithmetic optimizer stage that removes LogicalNot that takes a comparison as input, i.e.
!(a == b) => a != b !(a != b) => a == b !(a < b) => a >= b !(a <= b) => a > b !(a > b) => a <= b !(a >= b) => a < b PiperOrigin-RevId: 197477959
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc103
1 files changed, 103 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 99f93e6eec..64fdc8a83b 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -177,6 +177,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.remove_idempotent = true;
}
+
+ void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.remove_logical_not = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2737,5 +2742,103 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
}
}
+TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
+ Output b = ops::Const(s.WithOpName("b"), -3.14f, {32});
+ Output eq = ops::Equal(s.WithOpName("eq"), a, b);
+ Output neq = ops::NotEqual(s.WithOpName("neq"), a, b);
+ Output lt = ops::Less(s.WithOpName("lt"), a, b);
+ Output le = ops::LessEqual(s.WithOpName("le"), a, b);
+ Output gt = ops::Greater(s.WithOpName("gt"), a, b);
+ Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b);
+ // not_eq is reserved
+ Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq);
+ Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq);
+ Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt);
+ Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le);
+ Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt);
+ Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge);
+ Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1);
+ Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq);
+ Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt);
+ Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le);
+ Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt);
+ Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge);
+
+ GrapplerItem item;
+ item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
+ "id_not_le", "id_not_gt", "id_not_ge"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveLogicalNot(&optimizer);
+ OptimizeTwice(&optimizer, &item, &output);
+ LOG(INFO) << output.DebugString();
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "id_not_eq") {
+ EXPECT_EQ("eq", node.input(0));
+ ++found;
+ }
+ if (node.name() == "id_not_neq") {
+ EXPECT_EQ("neq", node.input(0));
+ ++found;
+ }
+ if (node.name() == "id_not_lt") {
+ EXPECT_EQ("lt", node.input(0));
+ ++found;
+ }
+ if (node.name() == "id_not_le") {
+ EXPECT_EQ("le", node.input(0));
+ ++found;
+ }
+ if (node.name() == "id_not_gt") {
+ EXPECT_EQ("gt", node.input(0));
+ ++found;
+ }
+ if (node.name() == "id_not_ge") {
+ EXPECT_EQ("ge", node.input(0));
+ ++found;
+ }
+
+ if (node.name() == "eq") {
+ EXPECT_EQ("NotEqual", node.op());
+ ++found;
+ }
+ if (node.name() == "neq") {
+ EXPECT_EQ("Equal", node.op());
+ ++found;
+ }
+ if (node.name() == "lt") {
+ EXPECT_EQ("GreaterEqual", node.op());
+ ++found;
+ }
+ if (node.name() == "le") {
+ EXPECT_EQ("Greater", node.op());
+ ++found;
+ }
+ if (node.name() == "gt") {
+ EXPECT_EQ("LessEqual", node.op());
+ ++found;
+ }
+ if (node.name() == "ge") {
+ EXPECT_EQ("Less", node.op());
+ ++found;
+ }
+ }
+ EXPECT_EQ(12, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(tensors.size(), tensors_expected.size());
+ EXPECT_EQ(tensors.size(), item.fetch.size());
+ for (int i = 0; i < item.fetch.size(); ++i) {
+ test::ExpectTensorEqual<bool>(tensors_expected[i], tensors[i]);
+ }
+}
+
} // namespace grappler
} // namespace tensorflow