aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-08 09:22:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-08 09:33:11 -0700
commit4136bd49d92c80de3c6ae03ffdb2524b36e96fa8 (patch)
tree64dd2dff838d5cca9257739cbfc061aa4212248b /tensorflow/contrib/data
parentf04f67f58fc6a5823fc4a78bd068c76f69d9fdd2 (diff)
[tf.data] Refactoring of optimization tests.
PiperOrigin-RevId: 212119773
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD15
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD35
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py64
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py)37
5 files changed, 97 insertions, 56 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b9320e5fef..6f0111a2bd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -286,21 +286,6 @@ py_test(
)
py_test(
- name = "optimize_dataset_op_test",
- size = "small",
- srcs = ["optimize_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
name = "parsing_ops_test",
size = "small",
srcs = ["parsing_ops_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index dc9d56dd53..55c9ac68dd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -209,7 +209,7 @@ class MapDatasetBenchmark(test.Benchmark):
end = time.time()
chained_deltas.append(end - start)
- fused_dataset = dataset = dataset.apply(
+ fused_dataset = dataset.apply(
batching.map_and_batch(
math_ops.matmul,
num_parallel_calls=num_calls,
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index b299e0736f..459bdf66f3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -7,6 +7,34 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
+ name = "assert_next_dataset_op_test",
+ size = "medium",
+ srcs = ["assert_next_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "latency_all_edges_test",
+ size = "small",
+ srcs = ["latency_all_edges_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
name = "map_vectorization_test",
size = "small",
srcs = ["map_vectorization_test.py"],
@@ -46,16 +74,15 @@ py_test(
)
py_test(
- name = "latency_all_edges_test",
+ name = "optimize_dataset_op_test",
size = "small",
- srcs = ["latency_all_edges_test.py"],
+ srcs = ["optimize_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
"//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
new file mode 100644
index 0000000000..bd7b50b902
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class AssertNextDatasetTest(test.TestCase):
+
+ def testAssertNext(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+
+ def testAssertNextInvalid(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted Whoops transformation at offset 0 but encountered "
+ "Map transformation instead."):
+ sess.run(get_next)
+
+ def testAssertNextShort(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted next 2 transformations but encountered only 1."):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index 089717156c..909da5aee0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import optimization
@@ -29,41 +28,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
-
- def testAssertSuffix(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertEqual(0, sess.run(get_next))
-
- def testAssertSuffixInvalid(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted Whoops transformation at offset 0 but encountered "
- "Map transformation instead."):
- sess.run(get_next)
-
- def testAssertSuffixShort(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted next 2 transformations but encountered only 1."):
- sess.run(get_next)
+class OptimizeDatasetTest(test.TestCase):
def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(