aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-27 08:55:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 09:00:23 -0700
commit77244534c0325f61509ac769efc8b462dec00b95 (patch)
tree8aa8b9f462d24d9e7e0c155a354c2549390e9292 /tensorflow/python/data
parent9a68681c3e9bf7e51423dcdbefd25da9c365d256 (diff)
[tf.data] Minor refactoring of tf.data tests.
PiperOrigin-RevId: 214781794
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD334
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py10
-rw-r--r--tensorflow/python/data/kernel_tests/cache_dataset_op_test.py5
-rw-r--r--tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_ops_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/inputs_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/multi_device_iterator_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/range_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py7
-rw-r--r--tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/shard_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py29
-rw-r--r--tensorflow/python/data/kernel_tests/window_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/zip_dataset_op_test.py3
24 files changed, 259 insertions, 203 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 7a6f03d4d3..fdcbfc3684 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -15,6 +15,7 @@ tf_py_test(
size = "small",
srcs = ["batch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -31,10 +32,44 @@ tf_py_test(
)
tf_py_test(
+ name = "cache_dataset_op_test",
+ size = "small",
+ srcs = ["cache_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+tf_py_test(
+ name = "concatenate_dataset_op_test",
+ size = "small",
+ srcs = ["concatenate_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_py_test(
name = "dataset_constructor_op_test",
size = "small",
srcs = ["dataset_constructor_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -63,6 +98,7 @@ tf_py_test(
size = "medium",
srcs = ["dataset_from_generator_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -78,6 +114,7 @@ tf_py_test(
size = "small",
srcs = ["dataset_ops_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python/data/ops:dataset_ops",
@@ -89,6 +126,7 @@ tf_py_test(
size = "small",
srcs = ["filter_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -106,6 +144,7 @@ tf_py_test(
size = "small",
srcs = ["flat_map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -123,6 +162,7 @@ tf_py_test(
size = "small",
srcs = ["list_files_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -133,91 +173,52 @@ tf_py_test(
)
tf_py_test(
- name = "interleave_dataset_op_test",
+ name = "inputs_test",
size = "small",
- srcs = ["interleave_dataset_op_test.py"],
+ srcs = ["inputs_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
- name = "map_dataset_op_test",
+ name = "interleave_dataset_op_test",
size = "small",
- srcs = ["map_dataset_op_test.py"],
+ srcs = ["interleave_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
- name = "prefetch_dataset_op_test",
- size = "small",
- srcs = ["prefetch_dataset_op_test.py"],
- additional_deps = [
- "@absl_py//absl/testing:parameterized",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
+ "//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
],
)
-tf_py_test(
- name = "range_dataset_op_test",
+cuda_py_test(
+ name = "iterator_ops_test",
size = "small",
- srcs = ["range_dataset_op_test.py"],
+ srcs = ["iterator_ops_test.py"],
additional_deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
- ],
-)
-
-tf_py_test(
- name = "reader_dataset_ops_test",
- size = "small",
- srcs = ["reader_dataset_ops_test.py"],
- additional_deps = [
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -225,82 +226,133 @@ tf_py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:gradients",
"//tensorflow/python:io_ops",
- "//tensorflow/python:lib",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python:training",
+ "//tensorflow/python/compat:compat",
"//tensorflow/python:util",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python:variables",
],
+ grpc_enabled = True,
)
tf_py_test(
- name = "sequence_dataset_op_test",
+ name = "iterator_ops_cluster_test",
size = "small",
- srcs = ["sequence_dataset_op_test.py"],
+ srcs = ["iterator_ops_cluster_test.py"],
additional_deps = [
- "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:session",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:lookup_ops",
+ ],
+ grpc_enabled = True,
+ tags = [
+ "no_oss", # Test flaky due to port collisions.
+ "no_windows",
],
)
tf_py_test(
- name = "shuffle_dataset_op_test",
+ name = "map_dataset_op_test",
size = "small",
- srcs = ["shuffle_dataset_op_test.py"],
+ srcs = ["map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
+ "//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
],
)
-tf_py_test(
- name = "shard_dataset_op_test",
+cuda_py_test(
+ name = "multi_device_iterator_test",
size = "small",
- srcs = ["shard_dataset_op_test.py"],
+ srcs = ["multi_device_iterator_test.py"],
additional_deps = [
+ ":test_base",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_windows_gpu",
],
)
-tf_py_test(
- name = "cache_dataset_op_test",
+cuda_py_test(
+ name = "optional_ops_test",
size = "small",
- srcs = ["cache_dataset_op_test.py"],
+ srcs = ["optional_ops_test.py"],
additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:tensor_shape",
],
)
tf_py_test(
- name = "zip_dataset_op_test",
+ name = "prefetch_dataset_op_test",
size = "small",
- srcs = ["zip_dataset_op_test.py"],
+ srcs = ["prefetch_dataset_op_test.py"],
additional_deps = [
- "//third_party/py/numpy",
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
@@ -308,32 +360,33 @@ tf_py_test(
)
tf_py_test(
- name = "concatenate_dataset_op_test",
+ name = "range_dataset_op_test",
size = "small",
- srcs = ["concatenate_dataset_op_test.py"],
+ srcs = ["range_dataset_op_test.py"],
additional_deps = [
- "//third_party/py/numpy",
+ ":test_base",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/ops:iterator_ops",
],
)
-cuda_py_test(
- name = "iterator_ops_test",
+tf_py_test(
+ name = "reader_dataset_ops_test",
size = "small",
- srcs = ["iterator_ops_test.py"],
+ srcs = ["reader_dataset_ops_test.py"],
additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/training/checkpointable:util",
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -341,91 +394,65 @@ cuda_py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:gradients",
"//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
+ "//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
- "//tensorflow/python:training",
- "//tensorflow/python/compat:compat",
"//tensorflow/python:util",
- "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
],
- grpc_enabled = True,
)
tf_py_test(
- name = "iterator_ops_cluster_test",
+ name = "sequence_dataset_op_test",
size = "small",
- srcs = ["iterator_ops_cluster_test.py"],
+ srcs = ["sequence_dataset_op_test.py"],
additional_deps = [
- "//tensorflow/core:protos_all_py",
+ ":test_base",
+ "//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:session",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:lookup_ops",
- ],
- grpc_enabled = True,
- tags = [
- "no_oss", # Test flaky due to port collisions.
- "no_windows",
],
)
-cuda_py_test(
- name = "optional_ops_test",
+tf_py_test(
+ name = "shard_dataset_op_test",
size = "small",
- srcs = ["optional_ops_test.py"],
+ srcs = ["shard_dataset_op_test.py"],
additional_deps = [
- "@absl_py//absl/testing:parameterized",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:optional_ops",
+ ":test_base",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
],
)
-cuda_py_test(
- name = "multi_device_iterator_test",
+tf_py_test(
+ name = "shuffle_dataset_op_test",
size = "small",
- srcs = ["multi_device_iterator_test.py"],
+ srcs = ["shuffle_dataset_op_test.py"],
additional_deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:multi_device_iterator_ops",
- "//tensorflow/python/data/ops:iterator_ops",
+ ":test_base",
+ "//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
],
- tags = [
- "no_windows_gpu",
+)
+
+py_library(
+ name = "test_base",
+ srcs = ["test_base.py"],
+ deps = [
+ "//tensorflow/python:client_testlib",
],
)
@@ -434,6 +461,7 @@ tf_py_test(
size = "small",
srcs = ["window_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -447,14 +475,16 @@ tf_py_test(
)
tf_py_test(
- name = "inputs_test",
+ name = "zip_dataset_op_test",
size = "small",
- srcs = ["inputs_test.py"],
+ srcs = ["zip_dataset_op_test.py"],
additional_deps = [
- "@absl_py//absl/testing:parameterized",
+ ":test_base",
"//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
],
)
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index c48708a2b9..9cb4daf284 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -37,7 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('even', 28, 14, False),
@@ -115,11 +116,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testBatchSparse(self):
def _sparse(i):
@@ -227,7 +223,7 @@ def _random_seq_lens(count):
return np.random.randint(20, size=(count,)).astype(np.int32)
-class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
+class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('default_padding', _random_seq_lens(32), 4, [-1], False),
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index d5f5b2fe05..63625fac03 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -23,6 +23,7 @@ import tempfile
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -34,7 +35,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class FileCacheDatasetTest(test.TestCase):
+class FileCacheDatasetTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
@@ -200,7 +201,7 @@ class FileCacheDatasetTest(test.TestCase):
self.assertAllEqual(elements, elements_itr2)
-class MemoryCacheDatasetTest(test.TestCase):
+class MemoryCacheDatasetTest(test_base.DatasetTestBase):
def testCacheDatasetPassthrough(self):
with ops.device("cpu:0"):
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index 5dfb84f28e..83af31f380 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class ConcatenateDatasetTest(test.TestCase):
+class ConcatenateDatasetTest(test_base.DatasetTestBase):
def testConcatenateDataset(self):
input_components = (
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
index e43564a2eb..bc6b36285a 100644
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -36,7 +37,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testFromTensors(self):
"""Test a dataset that represents a single tuple of tensors."""
@@ -58,11 +59,6 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testFromTensorsSparse(self):
"""Test a dataset that represents a single tuple of tensors."""
components = (sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index cd0c1ddf1e..cb8cb9a77d 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -22,6 +22,7 @@ import threading
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
output_types=None):
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index 239aa85175..f115f9d9c7 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -19,11 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
-class DatasetOpsTest(test.TestCase):
+class DatasetOpsTest(test_base.DatasetTestBase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 19944d389f..6b7afafa5d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -22,6 +22,7 @@ import time
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class FilterDatasetTest(test.TestCase):
+class FilterDatasetTest(test_base.DatasetTestBase):
def testFilterDataset(self):
components = (
@@ -129,11 +130,6 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _map_fn(i):
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
index 1123cbff62..68038f9cfc 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
-class FlatMapDatasetTest(test.TestCase):
+class FlatMapDatasetTest(test_base.DatasetTestBase):
# pylint: disable=g-long-lambda
def testFlatMapDataset(self):
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
index 4c9279dd95..d089b49bcc 100644
--- a/tensorflow/python/data/kernel_tests/inputs_test.py
+++ b/tensorflow/python/data/kernel_tests/inputs_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
@@ -27,7 +28,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
-class InputsTest(test.TestCase, parameterized.TestCase):
+class InputsTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def make_apply_fn(dataset):
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index e7e51df65e..92bb67b6ff 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
+class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index c4b338a58f..8eb13815d4 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -22,6 +22,7 @@ from os import path
import shutil
import tempfile
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class ListFilesDatasetOpTest(test.TestCase):
+class ListFilesDatasetOpTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index ae04995436..230ae3f3fd 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -47,7 +48,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-class MapDatasetTest(test.TestCase, parameterized.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
@@ -574,11 +575,6 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _sparse(i):
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
index 056664b83b..1cf6dd1bea 100644
--- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import dtypes
@@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MultiDeviceIteratorTest(test.TestCase):
+class MultiDeviceIteratorTest(test_base.DatasetTestBase):
def testNoGetNext(self):
dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index 706a65fe55..604e3ad88e 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
@@ -35,7 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class OptionalTest(test.TestCase, parameterized.TestCase):
+class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFromValue(self):
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index cc97bac609..76e2697b29 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
+class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((-1), (0), (5))
def testBufferSize(self, buffer_size):
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
index 51e90785e7..b7e2a5f615 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
@@ -34,7 +35,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def tearDown(self):
# Remove all checkpoint files.
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index aa3636364d..aef2dd1d9c 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -21,6 +21,7 @@ import gzip
import os
import zlib
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
@@ -46,7 +47,7 @@ except ImportError:
psutil_import_succeeded = False
-class TextLineDatasetTest(test.TestCase):
+class TextLineDatasetTest(test_base.DatasetTestBase):
def _lineText(self, f, l):
return compat.as_bytes("%d: %d" % (f, l))
@@ -199,7 +200,7 @@ class TextLineDatasetTest(test.TestCase):
self.assertNotIn(filename, [open_file.path for open_file in open_files])
-class FixedLengthRecordReaderTest(test.TestCase):
+class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
def setUp(self):
super(FixedLengthRecordReaderTest, self).setUp()
@@ -621,7 +622,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
sess.run(get_next_op)
-class TFRecordDatasetTest(test.TestCase):
+class TFRecordDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordDatasetTest, self).setUp()
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 37e2333560..e86356dee7 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SequenceDatasetTest(test.TestCase):
+class SequenceDatasetTest(test_base.DatasetTestBase):
def testRepeatTensorDataset(self):
"""Test a dataset that repeats its input multiple times."""
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index 137f6341ce..b9f3c79da5 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class ShardDatasetOpTest(test.TestCase):
+class ShardDatasetOpTest(test_base.DatasetTestBase):
def testSimpleCase(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index f294840706..347af18576 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -21,6 +21,7 @@ import collections
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ShuffleDatasetTest(test.TestCase):
+class ShuffleDatasetTest(test_base.DatasetTestBase):
def testShuffleDataset(self):
components = (
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
new file mode 100644
index 0000000000..b4f64115b7
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -0,0 +1,29 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.platform import test
+
+
+class DatasetTestBase(test.TestCase):
+ """Base class for dataset tests."""
+
+ def assertSparseValuesEqual(self, a, b):
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
index fd4348426d..9d06781094 100644
--- a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -150,11 +151,6 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
stride_t: stride
})
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testWindowSparse(self):
def _sparse(i):
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
index 3106effbd3..9d76387a34 100644
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ZipDatasetTest(test.TestCase):
+class ZipDatasetTest(test_base.DatasetTestBase):
def testZipDataset(self):
component_placeholders = [