diff options
author | 2018-09-08 09:22:58 -0700 | |
---|---|---|
committer | 2018-09-08 09:33:11 -0700 | |
commit | 4136bd49d92c80de3c6ae03ffdb2524b36e96fa8 (patch) | |
tree | 64dd2dff838d5cca9257739cbfc061aa4212248b /tensorflow/contrib/data | |
parent | f04f67f58fc6a5823fc4a78bd068c76f69d9fdd2 (diff) |
[tf.data] Refactoring of optimization tests.
PiperOrigin-RevId: 212119773
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimization/BUILD | 35 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py | 64 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py) | 37 |
5 files changed, 97 insertions, 56 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index b9320e5fef..6f0111a2bd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -286,21 +286,6 @@ py_test( ) py_test( - name = "optimize_dataset_op_test", - size = "small", - srcs = ["optimize_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( name = "parsing_ops_test", size = "small", srcs = ["parsing_ops_test.py"], diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index dc9d56dd53..55c9ac68dd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -209,7 +209,7 @@ class MapDatasetBenchmark(test.Benchmark): end = time.time() chained_deltas.append(end - start) - fused_dataset = dataset = dataset.apply( + fused_dataset = dataset.apply( batching.map_and_batch( math_ops.matmul, num_parallel_calls=num_calls, diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index b299e0736f..459bdf66f3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -7,6 +7,34 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") py_test( + name = "assert_next_dataset_op_test", + size = "medium", + srcs = ["assert_next_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "latency_all_edges_test", + size = "small", + srcs = ["latency_all_edges_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( name = "map_vectorization_test", size = "small", srcs = ["map_vectorization_test.py"], @@ -46,16 +74,15 @@ py_test( ) py_test( - name = "latency_all_edges_test", + name = "optimize_dataset_op_test", size = "small", - srcs = ["latency_all_edges_test.py"], + srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py new file mode 100644 index 0000000000..bd7b50b902 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class AssertNextDatasetTest(test.TestCase): + + def testAssertNext(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + + def testAssertNextInvalid(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted Whoops transformation at offset 0 but encountered " + "Map transformation instead."): + sess.run(get_next) + + def testAssertNextShort(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted next 2 transformations but encountered only 1."): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py index 089717156c..909da5aee0 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import optimization @@ -29,41 +28,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): - - def testAssertSuffix(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Map"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertEqual(0, sess.run(get_next)) - - def testAssertSuffixInvalid(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Whoops"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Asserted Whoops transformation at offset 0 but encountered " - "Map transformation instead."): - sess.run(get_next) - - def testAssertSuffixShort(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Asserted next 2 transformations but encountered only 1."): - sess.run(get_next) +class OptimizeDatasetTest(test.TestCase): def testOptimizationDefault(self): dataset = dataset_ops.Dataset.range(10).apply( |