aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2018-05-30 22:00:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-30 22:03:11 -0700
commit52a21f5df5ba0c7eeae91e4f818a6f2b989734cb (patch)
treef439a29814c8ba532ef52fff7e4fe93ad2cf1bc5 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parentf33d551ea6ed6a46c70cafd3a567933fe1159ddf (diff)
Improve ReshapeIsIdentity to work with symbolic shapes.
For example, with this CL, ArithmeticOptimizer can optimize the Reshape below into a no-op. s = Shape(t) Reshape(t, Concat(s[0], s[1], s[2], s[3])) PiperOrigin-RevId: 198668726
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc40
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index a908416e45..f678ea7227 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -989,6 +989,46 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
+TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
+ Output inputs_shape = ops::Shape(s, inputs);
+ // The target shape of the reshape is the concatenation of `batch_size`, 3,
+ // `height, and `width`.
+ Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output target_shape =
+ ops::Concat(s.WithOpName("target_shape"),
+ {batch_size, ops::Const(s, {3}, {1}), height, width},
+ ops::Const(s, {0}, {}));
+ Output reshape = ops::Reshape(s, inputs, target_shape);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
+ auto tensors_expected =
+ EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
+ EXPECT_EQ(1, tensors_expected.size());
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
+ .Optimize(nullptr, item, &output));
+
+ item.graph.Swap(&output);
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
+ auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =