diff options
author | Piotr Padlewski <prazek@google.com> | 2018-09-23 18:30:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-23 18:37:06 -0700 |
commit | fcd7840fbf49802be4bb7f67671465338b7b78a4 (patch) | |
tree | 017562cbd2b66b462a3562d667f3dae2a99c0ee5 /tensorflow/contrib/data | |
parent | 167272ead245ac9e0183da807d996ba9d6e401b0 (diff) |
Fix noop elimination optimization.
Fix for b/116169724
Only remove noops if they refer to const nodes.
PiperOrigin-RevId: 214199200
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimization/BUILD | 16 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py | 57 |
2 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index b3187bf61b..a2fc244ced 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -110,6 +110,22 @@ py_test( ) py_test( + name = "noop_elimination_test", + size = "small", + srcs = ["noop_elimination_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( name = "optimize_dataset_op_test", size = "small", srcs = ["optimize_dataset_op_test.py"], diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py new file mode 100644 index 0000000000..507feda3ad --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapParallelization optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class NoopEliminationTest(test.TestCase): + + def testNoopElimination(self): + a = constant_op.constant(1, dtype=dtypes.int64) + b = constant_op.constant(2, dtype=dtypes.int64) + some_tensor = math_ops.mul(a, b) + + dataset = dataset_ops.Dataset.range(5) + dataset = dataset.apply( + optimization.assert_next( + ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"])) + dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip( + 0).repeat(1).prefetch(0) + dataset = dataset.apply(optimization.optimize(["noop_elimination"])) + + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + for x in range(5): + result = sess.run(get_next) + self.assertAllEqual(result, x) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() |