aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2018-05-30 22:00:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-30 22:03:11 -0700
commit52a21f5df5ba0c7eeae91e4f818a6f2b989734cb (patch)
treef439a29814c8ba532ef52fff7e4fe93ad2cf1bc5
parentf33d551ea6ed6a46c70cafd3a567933fe1159ddf (diff)
Improve ReshapeIsIdentity to work with symbolic shapes.
For example, with this CL, ArithmeticOptimizer can optimize the Reshape below into a no-op. s = Shape(t) Reshape(t, Concat(s[0], s[1], s[2], s[3])) PiperOrigin-RevId: 198668726
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc35
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc40
2 files changed, 41 insertions, 34 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 9c18c45f18..e7f385cbd6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -209,40 +209,7 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
return false;
}
- const PartialTensorShape& src_shape = input_props[output_pos].shape();
- const PartialTensorShape& dst_shape = reshape_props[0].shape();
-
- if (src_shape.unknown_rank() || dst_shape.unknown_rank()) {
- return false;
- }
-
- if (!dst_shape.IsCompatibleWith(src_shape)) {
- return false;
- }
-
- // Returns false when src_shape or dst_shape has >=2 dimensions with unknown
- // sizes.
- auto num_unknown_dim_sizes = [](const PartialTensorShape& partial_shape) {
- auto dim_sizes = partial_shape.dim_sizes();
- return std::count_if(dim_sizes.begin(), dim_sizes.end(),
- [](int dim) { return dim < 0; });
- };
- int src_num_unknown_dim_sizes = num_unknown_dim_sizes(src_shape);
- int dst_num_unknown_dim_sizes = num_unknown_dim_sizes(dst_shape);
- if (src_num_unknown_dim_sizes > 1 || dst_num_unknown_dim_sizes > 1) {
- return false;
- }
-
- // If dst_num_unknown_dim_sizes != src_num_unknown_dim_sizes we would weaken
- // shape inference in subsequent passes if we removed this reshape.
- if (src_num_unknown_dim_sizes != dst_num_unknown_dim_sizes) {
- return false;
- }
-
- // Remove the reshape if both are fully defined or partially defined and the
- // unknown or symbolic shape appears on the same dimension, i.e., if
- // IsIdenticalTo returns true.
- return dst_shape.IsIdenticalTo(src_shape);
+ return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]);
}
NodeDef* GetTailOfValuePreservingChain(
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index a908416e45..f678ea7227 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -989,6 +989,46 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
+TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
+ Output inputs_shape = ops::Shape(s, inputs);
+ // The target shape of the reshape is the concatenation of `batch_size`, 3,
+ // `height, and `width`.
+ Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output target_shape =
+ ops::Concat(s.WithOpName("target_shape"),
+ {batch_size, ops::Const(s, {3}, {1}), height, width},
+ ops::Const(s, {0}, {}));
+ Output reshape = ops::Reshape(s, inputs, target_shape);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
+ auto tensors_expected =
+ EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
+ EXPECT_EQ(1, tensors_expected.size());
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
+ .Optimize(nullptr, item, &output));
+
+ item.graph.Swap(&output);
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
+ auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =