diff options
author | Piotr Padlewski <prazek@google.com> | 2018-09-23 18:30:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-23 18:37:06 -0700 |
commit | fcd7840fbf49802be4bb7f67671465338b7b78a4 (patch) | |
tree | 017562cbd2b66b462a3562d667f3dae2a99c0ee5 | |
parent | 167272ead245ac9e0183da807d996ba9d6e401b0 (diff) |
Fix noop elimination optimization.
Fix for b/116169724
Only remove noops if they refer to const nodes.
PiperOrigin-RevId: 214199200
6 files changed, 138 insertions, 7 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index b3187bf61b..a2fc244ced 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -110,6 +110,22 @@ py_test( ) py_test( + name = "noop_elimination_test", + size = "small", + srcs = ["noop_elimination_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( name = "optimize_dataset_op_test", size = "small", srcs = ["optimize_dataset_op_test.py"], diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py new file mode 100644 index 0000000000..507feda3ad --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapParallelization optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class NoopEliminationTest(test.TestCase): + + def testNoopElimination(self): + a = constant_op.constant(1, dtype=dtypes.int64) + b = constant_op.constant(2, dtype=dtypes.int64) + some_tensor = math_ops.mul(a, b) + + dataset = dataset_ops.Dataset.range(5) + dataset = dataset.apply( + optimization.assert_next( + ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"])) + dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip( + 0).repeat(1).prefetch(0) + dataset = dataset.apply(optimization.optimize(["noop_elimination"])) + + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + for x in range(5): + result = sess.run(get_next) + self.assertAllEqual(result, x) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index b3f60e34f9..2dd9ee822e 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -88,6 +88,16 @@ NodeDef* AddScalarConstNodeHelper( } // namespace +NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) { + NodeDef node; + node.set_op("Placeholder"); + SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node); + (*node.mutable_attr())["dtype"].set_type(dtype); + TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape(); + shape->set_unknown_rank(false); + return graph->AddNode(std::move(node)); +} + NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<string>& inputs, const std::vector<std::pair<string, AttrValue>>& attributes, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 1652afcd9e..b117482db2 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -37,6 +37,9 @@ NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<std::pair<string, AttrValue>>& attributes, MutableGraphView* graph); +// Adds Placeholder node for given type. +NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph); + // Adds a Const node with the given value to the graph. template <typename T> NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) { diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc index a26f1000a3..cf5a19bab1 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc @@ -33,25 +33,27 @@ namespace { bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) { if (take_node.op() != "TakeDataset") return false; - const NodeDef& count_node = *graph.GetNode(take_node.input(1)); + const auto& count_node = *graph.GetNode(take_node.input(1)); + if (count_node.op() != "Const") return false; // We are looking only for 'take' with negative count. return count_node.attr().at("value").tensor().int64_val(0) < 0; } +bool IsConstNodeWithValue(const NodeDef& node, int value) { + if (node.op() != "Const") return false; + return node.attr().at("value").tensor().int64_val(0) == value; +} + bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) { if (skip_node.op() != "SkipDataset") return false; - - const NodeDef& count_node = *graph.GetNode(skip_node.input(1)); // We are looking only for skip(0) nodes. - return count_node.attr().at("value").tensor().int64_val(0) == 0; + return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0); } bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) { if (repeat_node.op() != "RepeatDataset") return false; - - const NodeDef& count_node = *graph.GetNode(repeat_node.input(1)); // We are looking only for repeat(1) nodes. - return count_node.attr().at("value").tensor().int64_val(0) == 1; + return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1); } bool IsNoOp(const NodeDef& node, const GraphView& graph) { diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc index f445e75aa7..be1a66df75 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc @@ -43,6 +43,14 @@ NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node, GetCommonAttributes(), graph); } +NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node, + MutableGraphView *graph) { + NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph); + return graph_utils::AddNode("", node_type, + {std::move(input_node), node_count->name()}, + GetCommonAttributes(), graph); +} + NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) { NodeDef *node_filename = graph_utils::AddScalarConstNode<StringPiece>("", graph); @@ -205,6 +213,41 @@ INSTANTIATE_TEST_CASE_P( ::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode))); +struct NoOpPlaceholdersTest + : ::testing::TestWithParam<std::tuple<string, string>> {}; + +TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + + static_assert(std::tuple_size<NodesTypes>::value == 2, + "Make sure to include everything in the test"); + const std::vector<string> noop_nodes = {std::get<0>(GetParam()), + std::get<1>(GetParam())}; + NodeDef *range_node = MakeRangeNode(&graph); + std::vector<string> nodes_to_keep; + nodes_to_keep.reserve(noop_nodes.size()); + NodeDef *previous = range_node; + + for (const auto &noop_node : noop_nodes) { + NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph); + nodes_to_keep.push_back(node->name()); + previous = node; + } + + NoOpElimination optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + for (const auto &noop_node_name : nodes_to_keep) + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output)); +} + +INSTANTIATE_TEST_CASE_P( + DoNotRemovePlaceholders, NoOpPlaceholdersTest, + ::testing::Combine( + ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"), + ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"))); + } // namespace } // namespace grappler } // namespace tensorflow |