diff options
author | Jingyue Wu <jingyue@google.com> | 2018-05-30 22:00:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-30 22:03:11 -0700 |
commit | 52a21f5df5ba0c7eeae91e4f818a6f2b989734cb (patch) | |
tree | f439a29814c8ba532ef52fff7e4fe93ad2cf1bc5 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | f33d551ea6ed6a46c70cafd3a567933fe1159ddf (diff) |
Improve ReshapeIsIdentity to work with symbolic shapes.
For example, with this CL, ArithmeticOptimizer can optimize the Reshape below
into a no-op.
s = Shape(t)
Reshape(t, Concat(s[0], s[1], s[2], s[3]))
PiperOrigin-RevId: 198668726
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index a908416e45..f678ea7227 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -989,6 +989,46 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } +TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1})); + Output inputs_shape = ops::Shape(s, inputs); + // The target shape of the reshape is the concatenation of `batch_size`, 3, + // `height, and `width`. + Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}), + ops::Const(s, {1}, {1})); + Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}), + ops::Const(s, {1}, {1})); + Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}), + ops::Const(s, {1}, {1})); + Output target_shape = + ops::Concat(s.WithOpName("target_shape"), + {batch_size, ops::Const(s, {3}, {1}), height, width}, + ops::Const(s, {0}, {})); + Output reshape = ops::Reshape(s, inputs, target_shape); + Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28})); + auto tensors_expected = + EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) + .Optimize(nullptr, item, &output)); + + item.graph.Swap(&output); + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + EXPECT_EQ(0, CountOpNodes(output, "Reshape")); + auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); +} + TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = |