aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-30 14:55:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-30 15:28:43 -0700
commitd15f77048558a7af16648146faca1c5d13d8d6e1 (patch)
tree098fc91e752605870bc56b04251bd0198d991285 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parente469934f1274c7c498e5061995fec425a21c9be8 (diff)
Move RemoveInvolution optimization to optimizer stage.
PiperOrigin-RevId: 198624394
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc130
1 files changed, 75 insertions, 55 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 64fdc8a83b..a908416e45 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -115,12 +115,17 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.dedup_computations = false;
options.enable_try_simplify_and_replace = false;
options.combine_add_to_addn = false;
+ options.convert_sqrt_div_to_rsqrt_mul = false;
options.hoist_common_factor_out_of_aggregation = false;
+ options.hoist_cwise_unary_chains = false;
options.minimize_broadcasts = false;
options.remove_identity_transpose = false;
+ options.remove_involution = false;
+ options.remove_idempotent = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
options.remove_negation = false;
+ options.remove_logical_not = false;
optimizer->options_ = options;
}
@@ -148,6 +153,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.remove_identity_transpose = true;
}
+ void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.remove_involution = true;
+ }
+
void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_redundant_bitcast = true;
@@ -338,100 +348,110 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
+TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
- Output neg1 = ops::Neg(s.WithOpName("neg1"), c);
- Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
- Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
- Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
- Output id = ops::Identity(s.WithOpName("id"), recip2);
+
+ auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
+ auto neg1 = ops::Neg(s.WithOpName("neg1"), c);
+ auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
+ auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
+ auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
+ auto id = ops::Identity(s.WithOpName("id"), recip2);
+
+ std::vector<string> fetch = {"id"};
+
GrapplerItem item;
+ item.fetch = fetch;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"id"};
auto tensors_expected = EvaluateNodes(item.graph, fetch);
EXPECT_EQ(1, tensors_expected.size());
- ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveInvolution(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
- EXPECT_EQ(6, output.node_size());
+ // Negation and Reciprocal nodes cancelled each other.
+ EXPECT_EQ(2, output.node_size());
+ EXPECT_EQ("id", output.node(1).name());
EXPECT_EQ("c", output.node(1).input(0));
- EXPECT_EQ("c", output.node(3).input(0));
- EXPECT_EQ("c", output.node(5).input(0));
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
+TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AroundValuePreservingChain) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
- Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
- Output id1 = ops::Identity(s.WithOpName("id1"), recip1);
- Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
- Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
- Output id2 = ops::Identity(s.WithOpName("id2"), recip2);
+
+ auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
+ auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
+ auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
+ auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
+ auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
+ auto id2 = ops::Identity(s.WithOpName("id2"), recip2);
+
+ std::vector<string> fetch = {"id2"};
+
GrapplerItem item;
+ item.fetch = fetch;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"id2"};
auto tensors_expected = EvaluateNodes(item.graph, fetch);
EXPECT_EQ(1, tensors_expected.size());
- ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveInvolution(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
- EXPECT_EQ(6, output.node_size());
- EXPECT_EQ("squeeze", output.node(5).input(0));
- EXPECT_EQ("c", output.node(2).input(0));
+ // Check that Reciprocal nodes were removed from the graph.
+ EXPECT_EQ(3, output.node_size());
+
+ // And const directly flows into squeeze.
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "squeeze") {
+ EXPECT_EQ("c", node.input(0));
+ found++;
+ } else if (node.name() == "id2") {
+ EXPECT_EQ("squeeze", node.input(0));
+ found++;
+ }
+ }
+ EXPECT_EQ(2, found);
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) {
+TEST_F(ArithmeticOptimizerTest, RemoveInvolution_SkipControlDependencies) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
- Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
- Output id1 = ops::Identity(s.WithOpName("id1"), recip1);
- Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
- Output recip2 = ops::Reciprocal(
+
+ auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
+ auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
+ auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
+ auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
+ auto recip2 = ops::Reciprocal(
s.WithOpName("recip2").WithControlDependencies(squeeze), c);
- Output id2 = ops::Identity(s.WithOpName("id2"), recip2);
+ auto id2 = ops::Identity(s.WithOpName("id2"), recip2);
+
+ std::vector<string> fetch = {"id2"};
+
GrapplerItem item;
+ item.fetch = fetch;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"id2"};
auto tensors_expected = EvaluateNodes(item.graph, fetch);
EXPECT_EQ(1, tensors_expected.size());
- ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveInvolution(&optimizer);
+ OptimizeTwice(&optimizer, &item, &output); // do not prune in this test
// The optimizer should be a noop.
- EXPECT_EQ(item.graph.node_size(), output.node_size());
- for (int i = 0; i < item.graph.node_size(); ++i) {
- const NodeDef& original = item.graph.node(i);
- const NodeDef& optimized = output.node(i);
- EXPECT_EQ(original.name(), optimized.name());
- EXPECT_EQ(original.op(), optimized.op());
- EXPECT_EQ(original.input_size(), optimized.input_size());
- for (int j = 0; j < original.input_size(); ++j) {
- EXPECT_EQ(original.input(j), optimized.input(j));
- }
- }
+ VerifyGraphsMatch(item.graph, output, __LINE__);
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(1, tensors.size());
@@ -2777,7 +2797,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
ArithmeticOptimizer optimizer;
EnableOnlyRemoveLogicalNot(&optimizer);
OptimizeTwice(&optimizer, &item, &output);
- LOG(INFO) << output.DebugString();
+
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "id_not_eq") {