diff options
author | Piotr Padlewski <prazek@google.com> | 2018-09-23 18:30:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-23 18:37:06 -0700 |
commit | fcd7840fbf49802be4bb7f67671465338b7b78a4 (patch) | |
tree | 017562cbd2b66b462a3562d667f3dae2a99c0ee5 /tensorflow/core/grappler | |
parent | 167272ead245ac9e0183da807d996ba9d6e401b0 (diff) |
Fix noop elimination optimization.
Fix for b/116169724
Only remove noops if they refer to const nodes.
PiperOrigin-RevId: 214199200
Diffstat (limited to 'tensorflow/core/grappler')
4 files changed, 65 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index b3f60e34f9..2dd9ee822e 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -88,6 +88,16 @@ NodeDef* AddScalarConstNodeHelper( } // namespace +NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) { + NodeDef node; + node.set_op("Placeholder"); + SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node); + (*node.mutable_attr())["dtype"].set_type(dtype); + TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape(); + shape->set_unknown_rank(false); + return graph->AddNode(std::move(node)); +} + NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<string>& inputs, const std::vector<std::pair<string, AttrValue>>& attributes, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 1652afcd9e..b117482db2 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -37,6 +37,9 @@ NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<std::pair<string, AttrValue>>& attributes, MutableGraphView* graph); +// Adds Placeholder node for given type. +NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph); + // Adds a Const node with the given value to the graph. template <typename T> NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) { diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc index a26f1000a3..cf5a19bab1 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc @@ -33,25 +33,27 @@ namespace { bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) { if (take_node.op() != "TakeDataset") return false; - const NodeDef& count_node = *graph.GetNode(take_node.input(1)); + const auto& count_node = *graph.GetNode(take_node.input(1)); + if (count_node.op() != "Const") return false; // We are looking only for 'take' with negative count. return count_node.attr().at("value").tensor().int64_val(0) < 0; } +bool IsConstNodeWithValue(const NodeDef& node, int value) { + if (node.op() != "Const") return false; + return node.attr().at("value").tensor().int64_val(0) == value; +} + bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) { if (skip_node.op() != "SkipDataset") return false; - - const NodeDef& count_node = *graph.GetNode(skip_node.input(1)); // We are looking only for skip(0) nodes. - return count_node.attr().at("value").tensor().int64_val(0) == 0; + return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0); } bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) { if (repeat_node.op() != "RepeatDataset") return false; - - const NodeDef& count_node = *graph.GetNode(repeat_node.input(1)); // We are looking only for repeat(1) nodes. - return count_node.attr().at("value").tensor().int64_val(0) == 1; + return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1); } bool IsNoOp(const NodeDef& node, const GraphView& graph) { diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc index f445e75aa7..be1a66df75 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc @@ -43,6 +43,14 @@ NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node, GetCommonAttributes(), graph); } +NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node, + MutableGraphView *graph) { + NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph); + return graph_utils::AddNode("", node_type, + {std::move(input_node), node_count->name()}, + GetCommonAttributes(), graph); +} + NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) { NodeDef *node_filename = graph_utils::AddScalarConstNode<StringPiece>("", graph); @@ -205,6 +213,41 @@ INSTANTIATE_TEST_CASE_P( ::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode))); +struct NoOpPlaceholdersTest + : ::testing::TestWithParam<std::tuple<string, string>> {}; + +TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + + static_assert(std::tuple_size<NodesTypes>::value == 2, + "Make sure to include everything in the test"); + const std::vector<string> noop_nodes = {std::get<0>(GetParam()), + std::get<1>(GetParam())}; + NodeDef *range_node = MakeRangeNode(&graph); + std::vector<string> nodes_to_keep; + nodes_to_keep.reserve(noop_nodes.size()); + NodeDef *previous = range_node; + + for (const auto &noop_node : noop_nodes) { + NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph); + nodes_to_keep.push_back(node->name()); + previous = node; + } + + NoOpElimination optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + for (const auto &noop_node_name : nodes_to_keep) + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output)); +} + +INSTANTIATE_TEST_CASE_P( + DoNotRemovePlaceholders, NoOpPlaceholdersTest, + ::testing::Combine( + ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"), + ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"))); + } // namespace } // namespace grappler } // namespace tensorflow |