diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc | 210 |
1 files changed, 210 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc new file mode 100644 index 0000000000..a6cc63edba --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/noop_elimination.h" +#include <tuple> +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +std::vector<std::pair<string, AttrValue>> GetCommonAttributes() { + AttrValue shapes_attr, types_attr; + SetAttrValue("output_shapes", &shapes_attr); + SetAttrValue("output_types", &types_attr); + std::vector<std::pair<string, AttrValue>> commonAttributes = { + {"output_shapes", shapes_attr}, {"output_types", types_attr}}; + + return commonAttributes; +} + +NodeDef *MakeUnaryNode(const std::string &node_type, int count, + string input_node, MutableGraphView *graph) { + NodeDef *node_count = graph_utils::AddScalarConstNode<int64>(count, graph); + return graph_utils::AddNode("", node_type, + {std::move(input_node), node_count->name()}, + GetCommonAttributes(), graph); +} + +NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) { + NodeDef *node_filename = + graph_utils::AddScalarConstNode<StringPiece>("", graph); + return graph_utils::AddNode("", "CacheDataset", + {std::move(input_node), node_filename->name()}, + GetCommonAttributes(), graph); +} + +NodeDef *MakeRangeNode(MutableGraphView *graph) { + auto *start_node = graph_utils::AddScalarConstNode<int64>(0, graph); + auto *stop_node = graph_utils::AddScalarConstNode<int64>(10, graph); + auto *step_node = graph_utils::AddScalarConstNode<int64>(1, graph); + + std::vector<string> range_inputs = {start_node->name(), stop_node->name(), + step_node->name()}; + + return graph_utils::AddNode("", "RangeDataset", range_inputs, + GetCommonAttributes(), graph); +} + +struct NoOpLastEliminationTest + : ::testing::TestWithParam<std::tuple<std::string, int, bool>> {}; + +// This test checks whether the no-op elimination correctly handles +// transformations at the end of the pipeline. +TEST_P(NoOpLastEliminationTest, EliminateLastNoOpNode) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + + const std::string &node_type = std::get<0>(GetParam()); + const int node_count = std::get<1>(GetParam()); + const bool should_keep_node = std::get<2>(GetParam()); + + NodeDef *range_node = MakeRangeNode(&graph); + + NodeDef *node = + MakeUnaryNode(node_type, node_count, range_node->name(), &graph); + + NoOpElimination optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(graph_utils::ContainsGraphNodeWithName(node->name(), output), + should_keep_node); +} + +INSTANTIATE_TEST_CASE_P( + BasicRemovalTest, NoOpLastEliminationTest, + ::testing::Values(std::make_tuple("TakeDataset", -3, false), + std::make_tuple("TakeDataset", -1, false), + std::make_tuple("TakeDataset", 0, true), + std::make_tuple("TakeDataset", 3, true), + std::make_tuple("SkipDataset", -1, true), + std::make_tuple("SkipDataset", 0, false), + std::make_tuple("SkipDataset", 3, true), + std::make_tuple("RepeatDataset", 1, false), + std::make_tuple("RepeatDataset", 2, true))); + +struct NoOpMiddleEliminationTest + : ::testing::TestWithParam<std::tuple<std::string, int, bool>> {}; + +// This test checks whether the no-op elimination correctly handles +// transformations int the middle of the pipeline. +TEST_P(NoOpMiddleEliminationTest, EliminateMiddleNoOpNode) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + + const std::string &node_type = std::get<0>(GetParam()); + const int node_count = std::get<1>(GetParam()); + const bool should_keep_node = std::get<2>(GetParam()); + + NodeDef *range_node = MakeRangeNode(&graph); + + NodeDef *node = + MakeUnaryNode(node_type, node_count, range_node->name(), &graph); + + NodeDef *cache_node = MakeCacheNode(node->name(), &graph); + NoOpElimination optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(graph_utils::ContainsGraphNodeWithName(node->name(), output), + should_keep_node); + EXPECT_TRUE( + graph_utils::ContainsGraphNodeWithName(cache_node->name(), output)); + + NodeDef cache_node_out = output.node( + graph_utils::FindGraphNodeWithName(cache_node->name(), output)); + + EXPECT_EQ(cache_node_out.input_size(), 2); + auto last_node_input = (should_keep_node ? node : range_node)->name(); + EXPECT_EQ(cache_node_out.input(0), last_node_input); +} + +INSTANTIATE_TEST_CASE_P( + BasicRemovalTest, NoOpMiddleEliminationTest, + ::testing::Values(std::make_tuple("TakeDataset", -1, false), + std::make_tuple("TakeDataset", -3, false), + std::make_tuple("TakeDataset", 0, true), + std::make_tuple("TakeDataset", 3, true), + std::make_tuple("SkipDataset", -1, true), + std::make_tuple("SkipDataset", 0, false), + std::make_tuple("SkipDataset", 3, true), + std::make_tuple("RepeatDataset", 1, false), + std::make_tuple("RepeatDataset", 2, true))); + +using NodesTypes = std::tuple<std::pair<string, int>, std::pair<string, int>>; +struct NoOpMultipleEliminationTest : ::testing::TestWithParam<NodesTypes> {}; + +// This test checks whether the no-op elimination correctly removes +// multiple noop nodes. +TEST_P(NoOpMultipleEliminationTest, EliminateMultipleNoOpNode) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + + static_assert(std::tuple_size<NodesTypes>::value == 2, + "Make sure to include everything in the test"); + const std::vector<std::pair<string, int>> noop_nodes = { + std::get<0>(GetParam()), std::get<1>(GetParam())}; + + NodeDef *range_node = MakeRangeNode(&graph); + + NodeDef *previous = range_node; + std::vector<string> nodes_to_remove; + nodes_to_remove.reserve(noop_nodes.size()); + + for (const auto &noop_node : noop_nodes) { + NodeDef *node = MakeUnaryNode(noop_node.first, noop_node.second, + previous->name(), &graph); + nodes_to_remove.push_back(node->name()); + previous = node; + } + + NodeDef *cache_node = MakeCacheNode(previous->name(), &graph); + NoOpElimination optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + for (const auto &noop_node_name : nodes_to_remove) + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(noop_node_name, output)); + + EXPECT_TRUE( + graph_utils::ContainsGraphNodeWithName(cache_node->name(), output)); + + NodeDef cache_node_out = output.node( + graph_utils::FindGraphNodeWithName(cache_node->name(), output)); + + EXPECT_EQ(cache_node_out.input_size(), 2); + EXPECT_EQ(cache_node_out.input(0), range_node->name()); +} + +const auto *const kTakeNode = new std::pair<string, int>{"TakeDataset", -1}; +const auto *const kSkipNode = new std::pair<string, int>{"SkipDataset", 0}; +const auto *const kRepeatNode = new std::pair<string, int>{"RepeatDataset", 1}; + +INSTANTIATE_TEST_CASE_P( + BasicRemovalTest, NoOpMultipleEliminationTest, + ::testing::Combine(::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode), + ::testing::Values(*kTakeNode, *kSkipNode, + *kRepeatNode))); + +} // namespace +} // namespace grappler +} // namespace tensorflow |