aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2018-05-11 09:27:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-11 09:30:19 -0700
commit4aa456ef505f60fed357b9e321703468471304c7 (patch)
tree168690cbaff3ddc0efd9bccdb7c73aeac2a004b2 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent6a43945520afbf4a6e54923402ae65c1e8361dfa (diff)
ArithmeticOptimizer assumes valid feeds in aggressive mode.
ArithmeticOptimizer depends heavily on shapes in some stages. PiperOrigin-RevId: 196264319
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index d648fa0787..27c0dde419 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -964,6 +964,67 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
+TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
+ Output target_shape = ops::Const(s, {4, 3, 28, 28}, {4});
+ Output reshape = ops::Reshape(s, inputs, target_shape);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ item.feed = {{"Placeholder", x_t}};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+
+ item.graph.Swap(&output);
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ // The reshape is preserved because the shape of the placeholder can be
+ // different from the shape of the actual feed.
+ EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
+TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
+ Output target_shape = ops::Const(s, {4, 3, 28, 28}, {4});
+ Output reshape = ops::Reshape(s, inputs, target_shape);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ item.feed = {{"Placeholder", x_t}};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
+ .Optimize(nullptr, item, &output));
+
+ item.graph.Swap(&output);
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
// Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
// be from [4,3,28,28] to [8,6,28,28].