diff options
author | Jingyue Wu <jingyue@google.com> | 2018-05-11 09:27:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-11 09:30:19 -0700 |
commit | 4aa456ef505f60fed357b9e321703468471304c7 (patch) | |
tree | 168690cbaff3ddc0efd9bccdb7c73aeac2a004b2 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | 6a43945520afbf4a6e54923402ae65c1e8361dfa (diff) |
ArithmeticOptimizer assumes valid feeds in aggressive mode.
ArithmeticOptimizer depends heavily on shapes in some stages.
PiperOrigin-RevId: 196264319
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index d648fa0787..27c0dde419 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -964,6 +964,67 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } +TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28})); + Output target_shape = ops::Const(s, {4, 3, 28, 28}, {4}); + Output reshape = ops::Reshape(s, inputs, target_shape); + Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); + + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28})); + GrapplerItem item; + item.fetch = {"outputs"}; + item.feed = {{"Placeholder", x_t}}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + + item.graph.Swap(&output); + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + // The reshape is preserved because the shape of the placeholder can be + // different from the shape of the actual feed. + EXPECT_EQ(1, CountOpNodes(output, "Reshape")); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); +} + +TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28})); + Output target_shape = ops::Const(s, {4, 3, 28, 28}, {4}); + Output reshape = ops::Reshape(s, inputs, target_shape); + Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); + + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28})); + GrapplerItem item; + item.fetch = {"outputs"}; + item.feed = {{"Placeholder", x_t}}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) + .Optimize(nullptr, item, &output)); + + item.graph.Swap(&output); + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + EXPECT_EQ(0, CountOpNodes(output, "Reshape")); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); +} + TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can // be from [4,3,28,28] to [8,6,28,28]. |