diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-27 08:55:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 09:00:23 -0700 |
commit | 77244534c0325f61509ac769efc8b462dec00b95 (patch) | |
tree | 8aa8b9f462d24d9e7e0c155a354c2549390e9292 /tensorflow/python/data | |
parent | 9a68681c3e9bf7e51423dcdbefd25da9c365d256 (diff) |
[tf.data] Minor refactoring of tf.data tests.
PiperOrigin-RevId: 214781794
Diffstat (limited to 'tensorflow/python/data')
24 files changed, 259 insertions, 203 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 7a6f03d4d3..fdcbfc3684 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -15,6 +15,7 @@ tf_py_test( size = "small", srcs = ["batch_dataset_op_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", @@ -31,10 +32,44 @@ tf_py_test( ) tf_py_test( + name = "cache_dataset_op_test", + size = "small", + srcs = ["cache_dataset_op_test.py"], + additional_deps = [ + ":test_base", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +tf_py_test( + name = "concatenate_dataset_op_test", + size = "small", + srcs = ["concatenate_dataset_op_test.py"], + additional_deps = [ + ":test_base", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +tf_py_test( name = "dataset_constructor_op_test", size = "small", srcs = ["dataset_constructor_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -63,6 +98,7 @@ tf_py_test( size = "medium", srcs = ["dataset_from_generator_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -78,6 +114,7 @@ tf_py_test( size = "small", srcs = ["dataset_ops_test.py"], additional_deps = [ + ":test_base", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python/data/ops:dataset_ops", @@ -89,6 +126,7 @@ tf_py_test( size = "small", srcs = ["filter_dataset_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -106,6 +144,7 @@ tf_py_test( size = "small", srcs = ["flat_map_dataset_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -123,6 +162,7 @@ tf_py_test( size = "small", srcs = ["list_files_dataset_op_test.py"], additional_deps = [ + ":test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -133,91 +173,52 @@ tf_py_test( ) tf_py_test( - name = "interleave_dataset_op_test", + name = "inputs_test", size = "small", - srcs = ["interleave_dataset_op_test.py"], + srcs = ["inputs_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:session", - "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", ], ) tf_py_test( - name = "map_dataset_op_test", + name = "interleave_dataset_op_test", size = "small", - srcs = ["map_dataset_op_test.py"], + srcs = ["interleave_dataset_op_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:data_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:functional_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:script_ops", + "//tensorflow/python:session", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -tf_py_test( - name = "prefetch_dataset_op_test", - size = "small", - srcs = ["prefetch_dataset_op_test.py"], - additional_deps = [ - "@absl_py//absl/testing:parameterized", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", + "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", ], ) -tf_py_test( - name = "range_dataset_op_test", +cuda_py_test( + name = "iterator_ops_test", size = "small", - srcs = ["range_dataset_op_test.py"], + srcs = ["iterator_ops_test.py"], additional_deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:io_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variables", + "//third_party/py/numpy", + "//tensorflow/python/data/ops:readers", + "//tensorflow/core:protos_all_py", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", - ], -) - -tf_py_test( - name = "reader_dataset_ops_test", - size = "small", - srcs = ["reader_dataset_ops_test.py"], - additional_deps = [ + "//tensorflow/python/data/util:sparse", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:util", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -225,82 +226,133 @@ tf_py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python:gradients", "//tensorflow/python:io_ops", - "//tensorflow/python:lib", + "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python/compat:compat", "//tensorflow/python:util", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", + "//tensorflow/python:variables", ], + grpc_enabled = True, ) tf_py_test( - name = "sequence_dataset_op_test", + name = "iterator_ops_cluster_test", size = "small", - srcs = ["sequence_dataset_op_test.py"], + srcs = ["iterator_ops_cluster_test.py"], additional_deps = [ - "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python:session", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:string_ops", + "//tensorflow/python:lookup_ops", + ], + grpc_enabled = True, + tags = [ + "no_oss", # Test flaky due to port collisions. + "no_windows", ], ) tf_py_test( - name = "shuffle_dataset_op_test", + name = "map_dataset_op_test", size = "small", - srcs = ["shuffle_dataset_op_test.py"], + srcs = ["map_dataset_op_test.py"], additional_deps = [ + ":test_base", + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:data_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:functional_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:variable_scope", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", ], ) -tf_py_test( - name = "shard_dataset_op_test", +cuda_py_test( + name = "multi_device_iterator_test", size = "small", - srcs = ["shard_dataset_op_test.py"], + srcs = ["multi_device_iterator_test.py"], additional_deps = [ + ":test_base", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:multi_device_iterator_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python:framework_test_lib", + ], + tags = [ + "no_windows_gpu", ], ) -tf_py_test( - name = "cache_dataset_op_test", +cuda_py_test( + name = "optional_ops_test", size = "small", - srcs = ["cache_dataset_op_test.py"], + srcs = ["optional_ops_test.py"], additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/python:array_ops", + ":test_base", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:optional_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:tensor_shape", ], ) tf_py_test( - name = "zip_dataset_op_test", + name = "prefetch_dataset_op_test", size = "small", - srcs = ["zip_dataset_op_test.py"], + srcs = ["prefetch_dataset_op_test.py"], additional_deps = [ - "//third_party/py/numpy", + ":test_base", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", @@ -308,32 +360,33 @@ tf_py_test( ) tf_py_test( - name = "concatenate_dataset_op_test", + name = "range_dataset_op_test", size = "small", - srcs = ["concatenate_dataset_op_test.py"], + srcs = ["range_dataset_op_test.py"], additional_deps = [ - "//third_party/py/numpy", + ":test_base", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:io_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", + "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/ops:iterator_ops", ], ) -cuda_py_test( - name = "iterator_ops_test", +tf_py_test( + name = "reader_dataset_ops_test", size = "small", - srcs = ["iterator_ops_test.py"], + srcs = ["reader_dataset_ops_test.py"], additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/python/data/ops:readers", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/util:sparse", - "//tensorflow/python/eager:context", - "//tensorflow/python/training/checkpointable:util", + ":test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -341,91 +394,65 @@ cuda_py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:gradients", "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", + "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python/compat:compat", "//tensorflow/python:util", - "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", ], - grpc_enabled = True, ) tf_py_test( - name = "iterator_ops_cluster_test", + name = "sequence_dataset_op_test", size = "small", - srcs = ["iterator_ops_cluster_test.py"], + srcs = ["sequence_dataset_op_test.py"], additional_deps = [ - "//tensorflow/core:protos_all_py", + ":test_base", + "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:session", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:string_ops", - "//tensorflow/python:lookup_ops", - ], - grpc_enabled = True, - tags = [ - "no_oss", # Test flaky due to port collisions. - "no_windows", ], ) -cuda_py_test( - name = "optional_ops_test", +tf_py_test( + name = "shard_dataset_op_test", size = "small", - srcs = ["optional_ops_test.py"], + srcs = ["shard_dataset_op_test.py"], additional_deps = [ - "@absl_py//absl/testing:parameterized", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:optional_ops", + ":test_base", "//tensorflow/python:client_testlib", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", ], ) -cuda_py_test( - name = "multi_device_iterator_test", +tf_py_test( + name = "shuffle_dataset_op_test", size = "small", - srcs = ["multi_device_iterator_test.py"], + srcs = ["shuffle_dataset_op_test.py"], additional_deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:multi_device_iterator_ops", - "//tensorflow/python/data/ops:iterator_ops", + ":test_base", + "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", ], - tags = [ - "no_windows_gpu", +) + +py_library( + name = "test_base", + srcs = ["test_base.py"], + deps = [ + "//tensorflow/python:client_testlib", ], ) @@ -434,6 +461,7 @@ tf_py_test( size = "small", srcs = ["window_dataset_op_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", @@ -447,14 +475,16 @@ tf_py_test( ) tf_py_test( - name = "inputs_test", + name = "zip_dataset_op_test", size = "small", - srcs = ["inputs_test.py"], + srcs = ["zip_dataset_op_test.py"], additional_deps = [ - "@absl_py//absl/testing:parameterized", + ":test_base", "//third_party/py/numpy", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:sparse_tensor", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", ], ) diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index c48708a2b9..9cb4daf284 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,7 +38,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase, parameterized.TestCase): +class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ('even', 28, 14, False), @@ -115,11 +116,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testBatchSparse(self): def _sparse(i): @@ -227,7 +223,7 @@ def _random_seq_lens(count): return np.random.randint(20, size=(count,)).astype(np.int32) -class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): +class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ('default_padding', _random_seq_lens(32), 4, [-1], False), diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index d5f5b2fe05..63625fac03 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -23,6 +23,7 @@ import tempfile import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -34,7 +35,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class FileCacheDatasetTest(test.TestCase): +class FileCacheDatasetTest(test_base.DatasetTestBase): def setUp(self): self.tmp_dir = tempfile.mkdtemp() @@ -200,7 +201,7 @@ class FileCacheDatasetTest(test.TestCase): self.assertAllEqual(elements, elements_itr2) -class MemoryCacheDatasetTest(test.TestCase): +class MemoryCacheDatasetTest(test_base.DatasetTestBase): def testCacheDatasetPassthrough(self): with ops.device("cpu:0"): diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py index 5dfb84f28e..83af31f380 100644 --- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import errors @@ -26,7 +27,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test -class ConcatenateDatasetTest(test.TestCase): +class ConcatenateDatasetTest(test_base.DatasetTestBase): def testConcatenateDataset(self): input_components = ( diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py index e43564a2eb..bc6b36285a 100644 --- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -36,7 +37,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -class DatasetConstructorTest(test.TestCase): +class DatasetConstructorTest(test_base.DatasetTestBase): def testFromTensors(self): """Test a dataset that represents a single tuple of tensors.""" @@ -58,11 +59,6 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testFromTensorsSparse(self): """Test a dataset that represents a single tuple of tensors.""" components = (sparse_tensor.SparseTensorValue( diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py index cd0c1ddf1e..cb8cb9a77d 100644 --- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py @@ -22,6 +22,7 @@ import threading import numpy as np from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -30,7 +31,7 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class DatasetConstructorTest(test.TestCase): +class DatasetConstructorTest(test_base.DatasetTestBase): def _testFromGenerator(self, generator, elem_sequence, num_repeats, output_types=None): diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py index 239aa85175..f115f9d9c7 100644 --- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py @@ -19,11 +19,12 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class DatasetOpsTest(test.TestCase): +class DatasetOpsTest(test_base.DatasetTestBase): def testAsSerializedGraph(self): dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py index 19944d389f..6b7afafa5d 100644 --- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py @@ -22,6 +22,7 @@ import time import numpy as np from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -33,7 +34,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class FilterDatasetTest(test.TestCase): +class FilterDatasetTest(test_base.DatasetTestBase): def testFilterDataset(self): components = ( @@ -129,11 +130,6 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testSparse(self): def _map_fn(i): diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py index 1123cbff62..68038f9cfc 100644 --- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py @@ -22,6 +22,7 @@ import random import numpy as np from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor @@ -30,7 +31,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import server_lib -class FlatMapDatasetTest(test.TestCase): +class FlatMapDatasetTest(test_base.DatasetTestBase): # pylint: disable=g-long-lambda def testFlatMapDataset(self): diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py index 4c9279dd95..d089b49bcc 100644 --- a/tensorflow/python/data/kernel_tests/inputs_test.py +++ b/tensorflow/python/data/kernel_tests/inputs_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest @@ -27,7 +28,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.platform import test -class InputsTest(test.TestCase, parameterized.TestCase): +class InputsTest(test_base.DatasetTestBase, parameterized.TestCase): @staticmethod def make_apply_fn(dataset): diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py index e7e51df65e..92bb67b6ff 100644 --- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py @@ -22,6 +22,7 @@ import itertools from absl.testing import parameterized import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor @@ -30,7 +31,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetTest(test.TestCase, parameterized.TestCase): +class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _interleave(self, lists, cycle_length, block_length): num_open = 0 diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py index c4b338a58f..8eb13815d4 100644 --- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py @@ -22,6 +22,7 @@ from os import path import shutil import tempfile +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class ListFilesDatasetOpTest(test.TestCase): +class ListFilesDatasetOpTest(test_base.DatasetTestBase): def setUp(self): self.tmp_dir = tempfile.mkdtemp() diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index ae04995436..230ae3f3fd 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -27,6 +27,7 @@ import numpy as np from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -47,7 +48,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -class MapDatasetTest(test.TestCase, parameterized.TestCase): +class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _buildMapDataset(self, components, count): def _map_fn(x, y, z): @@ -574,11 +575,6 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testSparse(self): def _sparse(i): diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py index 056664b83b..1cf6dd1bea 100644 --- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py +++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import multi_device_iterator_ops from tensorflow.python.framework import dtypes @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MultiDeviceIteratorTest(test.TestCase): +class MultiDeviceIteratorTest(test_base.DatasetTestBase): def testNoGetNext(self): dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py index 706a65fe55..604e3ad88e 100644 --- a/tensorflow/python/data/kernel_tests/optional_ops_test.py +++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import optional_ops @@ -35,7 +36,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class OptionalTest(test.TestCase, parameterized.TestCase): +class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): @test_util.run_in_graph_and_eager_modes def testFromValue(self): diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py index cc97bac609..76e2697b29 100644 --- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class PrefetchDatasetTest(test.TestCase, parameterized.TestCase): +class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.parameters((-1), (0), (5)) def testBufferSize(self, buffer_size): diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py index 51e90785e7..b7e2a5f615 100644 --- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes @@ -34,7 +35,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test -class RangeDatasetTest(test.TestCase): +class RangeDatasetTest(test_base.DatasetTestBase): def tearDown(self): # Remove all checkpoint files. diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py index aa3636364d..aef2dd1d9c 100644 --- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,7 @@ import gzip import os import zlib +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers @@ -46,7 +47,7 @@ except ImportError: psutil_import_succeeded = False -class TextLineDatasetTest(test.TestCase): +class TextLineDatasetTest(test_base.DatasetTestBase): def _lineText(self, f, l): return compat.as_bytes("%d: %d" % (f, l)) @@ -199,7 +200,7 @@ class TextLineDatasetTest(test.TestCase): self.assertNotIn(filename, [open_file.path for open_file in open_files]) -class FixedLengthRecordReaderTest(test.TestCase): +class FixedLengthRecordReaderTest(test_base.DatasetTestBase): def setUp(self): super(FixedLengthRecordReaderTest, self).setUp() @@ -621,7 +622,7 @@ class FixedLengthRecordReaderTest(test.TestCase): sess.run(get_next_op) -class TFRecordDatasetTest(test.TestCase): +class TFRecordDatasetTest(test_base.DatasetTestBase): def setUp(self): super(TFRecordDatasetTest, self).setUp() diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py index 37e2333560..e86356dee7 100644 --- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SequenceDatasetTest(test.TestCase): +class SequenceDatasetTest(test_base.DatasetTestBase): def testRepeatTensorDataset(self): """Test a dataset that repeats its input multiple times.""" diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py index 137f6341ce..b9f3c79da5 100644 --- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py @@ -17,12 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.platform import test -class ShardDatasetOpTest(test.TestCase): +class ShardDatasetOpTest(test_base.DatasetTestBase): def testSimpleCase(self): dataset = dataset_ops.Dataset.range(10).shard(5, 2) diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py index f294840706..347af18576 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py @@ -21,6 +21,7 @@ import collections import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ShuffleDatasetTest(test.TestCase): +class ShuffleDatasetTest(test_base.DatasetTestBase): def testShuffleDataset(self): components = ( diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py new file mode 100644 index 0000000000..b4f64115b7 --- /dev/null +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -0,0 +1,29 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utilities for tf.data functionality.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import test + + +class DatasetTestBase(test.TestCase): + """Base class for dataset tests.""" + + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py index fd4348426d..9d06781094 100644 --- a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class WindowDatasetTest(test.TestCase, parameterized.TestCase): +class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("1", 20, 14, 7, 1), @@ -150,11 +151,6 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): stride_t: stride }) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testWindowSparse(self): def _sparse(i): diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py index 3106effbd3..9d76387a34 100644 --- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ZipDatasetTest(test.TestCase): +class ZipDatasetTest(test_base.DatasetTestBase): def testZipDataset(self): component_placeholders = [ |