diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-21 16:14:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-21 16:18:47 -0700 |
commit | 4d03411da6fcc803d9abcef97a59072144e325f9 (patch) | |
tree | e336af7175e3ed146f80a38d3275c5026cdc1f39 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | d67a99406f986021df9a0f8a99ff6eaf801dcb25 (diff) |
Add arithmetic optimizer stage that removes LogicalNot that takes a comparison as input, i.e.
!(a == b) => a != b
!(a != b) => a == b
!(a < b) => a >= b
!(a <= b) => a > b
!(a > b) => a <= b
!(a >= b) => a < b
PiperOrigin-RevId: 197477959
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 99f93e6eec..64fdc8a83b 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -177,6 +177,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.remove_idempotent = true; } + + void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_logical_not = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -2737,5 +2742,103 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { } } +TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 3.14f, {32}); + Output b = ops::Const(s.WithOpName("b"), -3.14f, {32}); + Output eq = ops::Equal(s.WithOpName("eq"), a, b); + Output neq = ops::NotEqual(s.WithOpName("neq"), a, b); + Output lt = ops::Less(s.WithOpName("lt"), a, b); + Output le = ops::LessEqual(s.WithOpName("le"), a, b); + Output gt = ops::Greater(s.WithOpName("gt"), a, b); + Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b); + // not_eq is reserved + Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq); + Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq); + Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt); + Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le); + Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt); + Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge); + Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1); + Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq); + Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt); + Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le); + Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt); + Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge); + + GrapplerItem item; + item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt", + "id_not_le", "id_not_gt", "id_not_ge"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveLogicalNot(&optimizer); + OptimizeTwice(&optimizer, &item, &output); + LOG(INFO) << output.DebugString(); + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "id_not_eq") { + EXPECT_EQ("eq", node.input(0)); + ++found; + } + if (node.name() == "id_not_neq") { + EXPECT_EQ("neq", node.input(0)); + ++found; + } + if (node.name() == "id_not_lt") { + EXPECT_EQ("lt", node.input(0)); + ++found; + } + if (node.name() == "id_not_le") { + EXPECT_EQ("le", node.input(0)); + ++found; + } + if (node.name() == "id_not_gt") { + EXPECT_EQ("gt", node.input(0)); + ++found; + } + if (node.name() == "id_not_ge") { + EXPECT_EQ("ge", node.input(0)); + ++found; + } + + if (node.name() == "eq") { + EXPECT_EQ("NotEqual", node.op()); + ++found; + } + if (node.name() == "neq") { + EXPECT_EQ("Equal", node.op()); + ++found; + } + if (node.name() == "lt") { + EXPECT_EQ("GreaterEqual", node.op()); + ++found; + } + if (node.name() == "le") { + EXPECT_EQ("Greater", node.op()); + ++found; + } + if (node.name() == "gt") { + EXPECT_EQ("LessEqual", node.op()); + ++found; + } + if (node.name() == "ge") { + EXPECT_EQ("Less", node.op()); + ++found; + } + } + EXPECT_EQ(12, found); + + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(tensors.size(), tensors_expected.size()); + EXPECT_EQ(tensors.size(), item.fetch.size()); + for (int i = 0; i < item.fetch.size(); ++i) { + test::ExpectTensorEqual<bool>(tensors_expected[i], tensors[i]); + } +} + } // namespace grappler } // namespace tensorflow |