diff options
Diffstat (limited to 'tensorflow/python/data/kernel_tests')
25 files changed, 760 insertions, 212 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 7a6f03d4d3..c7295d6e69 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -15,6 +15,7 @@ tf_py_test( size = "small", srcs = ["batch_dataset_op_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", @@ -31,10 +32,44 @@ tf_py_test( ) tf_py_test( + name = "cache_dataset_op_test", + size = "small", + srcs = ["cache_dataset_op_test.py"], + additional_deps = [ + ":test_base", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +tf_py_test( + name = "concatenate_dataset_op_test", + size = "small", + srcs = ["concatenate_dataset_op_test.py"], + additional_deps = [ + ":test_base", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +tf_py_test( name = "dataset_constructor_op_test", size = "small", srcs = ["dataset_constructor_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -63,6 +98,7 @@ tf_py_test( size = "medium", srcs = ["dataset_from_generator_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -78,8 +114,11 @@ tf_py_test( size = "small", srcs = ["dataset_ops_test.py"], additional_deps = [ - "//tensorflow/core:protos_all_py", + ":test_base", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -89,6 +128,7 @@ tf_py_test( size = "small", srcs = ["filter_dataset_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -106,6 +146,7 @@ tf_py_test( size = "small", srcs = ["flat_map_dataset_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -123,6 +164,7 @@ tf_py_test( size = "small", srcs = ["list_files_dataset_op_test.py"], additional_deps = [ + ":test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -137,6 +179,7 @@ tf_py_test( size = "small", srcs = ["interleave_dataset_op_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", @@ -151,11 +194,80 @@ tf_py_test( ], ) +cuda_py_test( + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python/data/ops:readers", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/util:sparse", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:io_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python/compat:compat", + "//tensorflow/python:util", + "//tensorflow/python:variables", + ], + grpc_enabled = True, +) + +tf_py_test( + name = "iterator_ops_cluster_test", + size = "small", + srcs = ["iterator_ops_cluster_test.py"], + additional_deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python:session", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:string_ops", + "//tensorflow/python:lookup_ops", + ], + grpc_enabled = True, + tags = [ + "no_oss", # Test flaky due to port collisions. + "no_windows", + ], +) + tf_py_test( name = "map_dataset_op_test", size = "small", srcs = ["map_dataset_op_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", @@ -177,11 +289,54 @@ tf_py_test( ], ) +cuda_py_test( + name = "multi_device_iterator_test", + size = "medium", + srcs = ["multi_device_iterator_test.py"], + additional_deps = [ + ":test_base", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:multi_device_iterator_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + ], + tags = [ + "no_windows_gpu", + ], +) + +cuda_py_test( + name = "optional_ops_test", + size = "small", + srcs = ["optional_ops_test.py"], + additional_deps = [ + ":test_base", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:optional_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:tensor_shape", + ], +) + tf_py_test( name = "prefetch_dataset_op_test", size = "small", srcs = ["prefetch_dataset_op_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -197,6 +352,7 @@ tf_py_test( size = "small", srcs = ["range_dataset_op_test.py"], additional_deps = [ + ":test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dataset_ops_gen", @@ -218,6 +374,7 @@ tf_py_test( size = "small", srcs = ["reader_dataset_ops_test.py"], additional_deps = [ + ":test_base", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -236,32 +393,35 @@ tf_py_test( ) tf_py_test( - name = "sequence_dataset_op_test", + name = "reduce_dataset_op_test", size = "small", - srcs = ["sequence_dataset_op_test.py"], + srcs = ["reduce_dataset_op_test.py"], additional_deps = [ + ":test_base", + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", ], ) tf_py_test( - name = "shuffle_dataset_op_test", + name = "sequence_dataset_op_test", size = "small", - srcs = ["shuffle_dataset_op_test.py"], + srcs = ["sequence_dataset_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -270,6 +430,7 @@ tf_py_test( size = "small", srcs = ["shard_dataset_op_test.py"], additional_deps = [ + ":test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", @@ -277,155 +438,30 @@ tf_py_test( ) tf_py_test( - name = "cache_dataset_op_test", + name = "shuffle_dataset_op_test", size = "small", - srcs = ["cache_dataset_op_test.py"], + srcs = ["shuffle_dataset_op_test.py"], additional_deps = [ + ":test_base", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", ], ) -tf_py_test( - name = "zip_dataset_op_test", - size = "small", - srcs = ["zip_dataset_op_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -tf_py_test( - name = "concatenate_dataset_op_test", - size = "small", - srcs = ["concatenate_dataset_op_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - ], -) - -cuda_py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/python/data/ops:readers", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/util:sparse", - "//tensorflow/python/eager:context", - "//tensorflow/python/training/checkpointable:util", - "//tensorflow/python:array_ops", +py_library( + name = "test_base", + srcs = ["test_base.py"], + deps = [ "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python/compat:compat", - "//tensorflow/python:util", - "//tensorflow/python:variables", - ], - grpc_enabled = True, -) - -tf_py_test( - name = "iterator_ops_cluster_test", - size = "small", - srcs = ["iterator_ops_cluster_test.py"], - additional_deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:session", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:string_ops", - "//tensorflow/python:lookup_ops", - ], - grpc_enabled = True, - tags = [ - "no_oss", # Test flaky due to port collisions. - "no_windows", - ], -) - -cuda_py_test( - name = "optional_ops_test", - size = "small", - srcs = ["optional_ops_test.py"], - additional_deps = [ - "@absl_py//absl/testing:parameterized", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:optional_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:tensor_shape", - ], -) - -cuda_py_test( - name = "multi_device_iterator_test", - size = "small", - srcs = ["multi_device_iterator_test.py"], - additional_deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:multi_device_iterator_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", - ], - tags = [ - "no_windows_gpu", + "//tensorflow/python/data/util:nest", ], ) @@ -434,6 +470,7 @@ tf_py_test( size = "small", srcs = ["window_dataset_op_test.py"], additional_deps = [ + ":test_base", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", @@ -447,14 +484,16 @@ tf_py_test( ) tf_py_test( - name = "inputs_test", + name = "zip_dataset_op_test", size = "small", - srcs = ["inputs_test.py"], + srcs = ["zip_dataset_op_test.py"], additional_deps = [ - "@absl_py//absl/testing:parameterized", + ":test_base", "//third_party/py/numpy", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:sparse_tensor", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", ], ) diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index c48708a2b9..9cb4daf284 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,7 +38,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase, parameterized.TestCase): +class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ('even', 28, 14, False), @@ -115,11 +116,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testBatchSparse(self): def _sparse(i): @@ -227,7 +223,7 @@ def _random_seq_lens(count): return np.random.randint(20, size=(count,)).astype(np.int32) -class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): +class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ('default_padding', _random_seq_lens(32), 4, [-1], False), diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index d5f5b2fe05..63625fac03 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -23,6 +23,7 @@ import tempfile import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -34,7 +35,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class FileCacheDatasetTest(test.TestCase): +class FileCacheDatasetTest(test_base.DatasetTestBase): def setUp(self): self.tmp_dir = tempfile.mkdtemp() @@ -200,7 +201,7 @@ class FileCacheDatasetTest(test.TestCase): self.assertAllEqual(elements, elements_itr2) -class MemoryCacheDatasetTest(test.TestCase): +class MemoryCacheDatasetTest(test_base.DatasetTestBase): def testCacheDatasetPassthrough(self): with ops.device("cpu:0"): diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py index 5dfb84f28e..83af31f380 100644 --- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import errors @@ -26,7 +27,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test -class ConcatenateDatasetTest(test.TestCase): +class ConcatenateDatasetTest(test_base.DatasetTestBase): def testConcatenateDataset(self): input_components = ( diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py index e43564a2eb..bc6b36285a 100644 --- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -36,7 +37,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -class DatasetConstructorTest(test.TestCase): +class DatasetConstructorTest(test_base.DatasetTestBase): def testFromTensors(self): """Test a dataset that represents a single tuple of tensors.""" @@ -58,11 +59,6 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testFromTensorsSparse(self): """Test a dataset that represents a single tuple of tensors.""" components = (sparse_tensor.SparseTensorValue( diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py index cd0c1ddf1e..cb8cb9a77d 100644 --- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py @@ -22,6 +22,7 @@ import threading import numpy as np from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -30,7 +31,7 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class DatasetConstructorTest(test.TestCase): +class DatasetConstructorTest(test_base.DatasetTestBase): def _testFromGenerator(self, generator, elem_sequence, num_repeats, output_types=None): diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py index 239aa85175..b9f8875b9f 100644 --- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py @@ -18,12 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized +import numpy as np + from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.platform import test -class DatasetOpsTest(test.TestCase): +class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def testAsSerializedGraph(self): dataset = dataset_ops.Dataset.range(10) @@ -32,6 +40,155 @@ class DatasetOpsTest(test.TestCase): sess.run(dataset._as_serialized_graph())) self.assertTrue(any([node.op != "RangeDataset" for node in graph.node])) + @staticmethod + def make_apply_fn(dataset): + + def apply_fn(dataset): + + def _apply_fn(dataset): + return dataset.cache() + + return dataset.apply(_apply_fn) + + return apply_fn + + @staticmethod + def make_gen(): + + def gen(): + yield 42 + + return gen + + @staticmethod + def make_interleave_fn(dataset, num_parallel_calls=None): + + def interleave_fn(dataset): + return dataset.interleave( + lambda x: dataset_ops.Dataset.range(0), + cycle_length=2, + num_parallel_calls=num_parallel_calls) + + return interleave_fn + + @parameterized.named_parameters( + ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)), + ("FromGenerator", + dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32), + 1), + ("FromSparseTensorSlices", + dataset_ops.Dataset.from_sparse_tensor_slices( + sparse_tensor.SparseTensor( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])))), + ("FromTensors", dataset_ops.Dataset.from_tensors([42])), + ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])), + ("Range", dataset_ops.Dataset.range(10)), + ("TextLine", readers.TextLineDataset("")), + ("TFRecord", readers.TFRecordDataset(""), 1), + ) + def testDatasetSourceInputs(self, dataset, num_inputs=0): + self.assertEqual(num_inputs, len(dataset._inputs())) + + @parameterized.named_parameters( + ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)), + dataset_ops.Dataset.range(0)), + ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)), + ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)), + ("Filter", lambda x: x.filter(lambda x: True), + dataset_ops.Dataset.range(0)), + ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)), + dataset_ops.Dataset.range(0)), + ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)), + dataset_ops.Dataset.range(0)), + ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)), + ("PaddedBatch", lambda x: x.padded_batch(10, []), + dataset_ops.Dataset.range(0)), + ("ParallelInterleave", + make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2), + dataset_ops.Dataset.range(0)), + ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2), + dataset_ops.Dataset.range(0)), + ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)), + ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)), + ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)), + ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)), + ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)), + ) + def testUnaryTransformationInputs(self, dataset_fn, input_dataset): + self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs()) + + @parameterized.named_parameters( + ("Concatenate", lambda x, y: x.concatenate(y), + dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))) + def testBinaryTransformationInputs(self, dataset_fn, input1, input2): + self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs()) + + @parameterized.named_parameters( + ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))), + ("ZipNest", dataset_ops.Dataset.zip, + (dataset_ops.Dataset.range(0), + (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))), + ("ZipTuple", dataset_ops.Dataset.zip, + (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))) + def testVariadicTransformationInputs(self, dataset_fn, input_datasets): + self.assertEqual( + nest.flatten(input_datasets), + dataset_fn(input_datasets)._inputs()) + + def testCollectInputs(self): + ds1 = dataset_ops.Dataset.range(0) + ds2 = ds1.concatenate(ds1) + ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2)) + + inputs = [] + queue = [ds3] + while queue: + ds = queue[0] + queue = queue[1:] + queue.extend(ds._inputs()) + inputs.append(ds) + + self.assertEqual(5, inputs.count(ds1)) + self.assertEqual(2, inputs.count(ds2)) + self.assertEqual(1, inputs.count(ds3)) + + def testOptionsDefault(self): + ds = dataset_ops.Dataset.range(0) + self.assertEqual(dataset_ops.Options(), ds.options()) + + def testOptionsOnce(self): + options = dataset_ops.Options() + ds = dataset_ops.Dataset.range(0).with_options(options).cache() + self.assertEqual(options, ds.options()) + + def testOptionsTwiceSame(self): + options = dataset_ops.Options() + options.experimental_autotune = True + ds = dataset_ops.Dataset.range(0).with_options(options).with_options( + options) + self.assertEqual(options, ds.options()) + + def testOptionsTwiceDifferent(self): + options1 = dataset_ops.Options() + options1.experimental_autotune = True + options2 = dataset_ops.Options() + options2.experimental_filter_fusion = False + ds = dataset_ops.Dataset.range(0).with_options(options1).with_options( + options2) + self.assertTrue(ds.options().experimental_autotune) + self.assertFalse(ds.options().experimental_filter_fusion) + + def testOptionsTwiceDifferentError(self): + options1 = dataset_ops.Options() + options1.experimental_autotune = True + options2 = dataset_ops.Options() + options2.experimental_autotune = False + with self.assertRaisesRegexp(ValueError, + "Cannot merge incompatible values of option"): + dataset_ops.Dataset.range(0).with_options(options1).with_options(options2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py index 19944d389f..a0c6b37a6d 100644 --- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py @@ -22,6 +22,7 @@ import time import numpy as np from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -33,7 +34,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class FilterDatasetTest(test.TestCase): +class FilterDatasetTest(test_base.DatasetTestBase): def testFilterDataset(self): components = ( @@ -129,11 +130,6 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testSparse(self): def _map_fn(i): @@ -160,7 +156,7 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testReturnComponent(self): + def testShortCircuit(self): iterator = ( dataset_ops.Dataset.zip( (dataset_ops.Dataset.range(10), diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py index 1123cbff62..68038f9cfc 100644 --- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py @@ -22,6 +22,7 @@ import random import numpy as np from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor @@ -30,7 +31,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import server_lib -class FlatMapDatasetTest(test.TestCase): +class FlatMapDatasetTest(test_base.DatasetTestBase): # pylint: disable=g-long-lambda def testFlatMapDataset(self): diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py index 4c9279dd95..d089b49bcc 100644 --- a/tensorflow/python/data/kernel_tests/inputs_test.py +++ b/tensorflow/python/data/kernel_tests/inputs_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest @@ -27,7 +28,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.platform import test -class InputsTest(test.TestCase, parameterized.TestCase): +class InputsTest(test_base.DatasetTestBase, parameterized.TestCase): @staticmethod def make_apply_fn(dataset): diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py index e7e51df65e..92bb67b6ff 100644 --- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py @@ -22,6 +22,7 @@ import itertools from absl.testing import parameterized import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor @@ -30,7 +31,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetTest(test.TestCase, parameterized.TestCase): +class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _interleave(self, lists, cycle_length, block_length): num_open = 0 diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py index c4b338a58f..8eb13815d4 100644 --- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py @@ -22,6 +22,7 @@ from os import path import shutil import tempfile +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class ListFilesDatasetOpTest(test.TestCase): +class ListFilesDatasetOpTest(test_base.DatasetTestBase): def setUp(self): self.tmp_dir = tempfile.mkdtemp() diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index ae04995436..4683b1db91 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for `tf.data.Dataset.map()`.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -27,6 +27,7 @@ import numpy as np from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -47,7 +48,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -class MapDatasetTest(test.TestCase, parameterized.TestCase): +class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _buildMapDataset(self, components, count): def _map_fn(x, y, z): @@ -266,6 +267,35 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testCaptureIterator(self): + + def _build_ds(iterator): + + def _map_fn(x): + get_next = iterator.get_next() + return x * get_next + + return dataset_ops.Dataset.range(10).map(_map_fn) + + def _build_graph(): + captured_iterator = dataset_ops.Dataset.range( + 10).make_initializable_iterator() + ds = _build_ds(captured_iterator) + iterator = ds.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return captured_iterator.initializer, init_op, get_next + + with ops.Graph().as_default() as g: + captured_init_op, init_op, get_next = _build_graph() + with self.session(graph=g) as sess: + sess.run(captured_init_op) + sess.run(init_op) + for i in range(10): + self.assertEqual(i * i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testCaptureHashTable(self): # NOTE(mrry): We must use the V2 variants of `HashTable` # etc. because these produce a `tf.resource`-typed output that is @@ -574,11 +604,6 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testSparse(self): def _sparse(i): @@ -597,7 +622,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): sess.run(init_op) for i in range(10): actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertIsInstance(actual, sparse_tensor.SparseTensorValue) self.assertSparseValuesEqual(actual, _sparse(i)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -624,7 +649,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): sess.run(init_op) for i in range(10): actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertIsInstance(actual, sparse_tensor.SparseTensorValue) self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval()) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -758,19 +783,72 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase): self.assertTrue(all(tids[0] == tid for tid in tids)) # pylint: enable=g-long-lambda + @parameterized.named_parameters( + ("SequentialIdentity", None, lambda x: x, None), + ("SequentialReplicate", None, lambda x: (x, x), None), + ("SequentialSwap", (None, None), lambda x, y: (y, x), None), + ("SequentialProject", (None, None), lambda x, y: x, None), + ("ParallelIdentity", None, lambda x: x, 10), + ("ParallelReplicate", None, lambda x: (x, x), 10), + ("ParallelSwap", (None, None), lambda x, y: (y, x), 10), + ("ParallelProject", (None, None), lambda x, y: x, 10), + ) + def testShortCircuit(self, structure, map_fn, num_parallel_calls): + dataset = self.structuredDataset(structure).repeat().map( + map_fn, num_parallel_calls=num_parallel_calls) + get_next = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + if isinstance(structure, tuple): + expected = map_fn(*sess.run(self.structuredElement(structure))) + else: + expected = map_fn(sess.run(self.structuredElement(structure))) + self.assertEqual(expected, sess.run(get_next)) + + @parameterized.named_parameters( + ("Sequential", None), + ("Parallel", 10), + ) + def testShortCircuitCapturedInput(self, num_parallel_calls): + captured_t = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = self.structuredDataset(None).repeat().map( + lambda x: captured_t, num_parallel_calls=num_parallel_calls) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer, feed_dict={captured_t: 42}) + self.assertEqual(42, sess.run(get_next)) + class MapDatasetBenchmark(test.Benchmark): def benchmarkChainOfMaps(self): chain_lengths = [0, 1, 2, 5, 10, 20, 50] for chain_length in chain_lengths: - for use_inter_op_parallelism in [False, True]: + for mode in ["general", "single-threaded", "short-circuit"]: + if mode == "general": + map_fn = lambda x: x + 1 + use_inter_op_parallelism = True + print_label = "" + benchmark_label = "" + if mode == "single-threaded": + map_fn = lambda x: x + 1 + use_inter_op_parallelism = False + print_label = " (single threaded mode)" + benchmark_label = "_single_threaded" + if mode == "short-circuit": + map_fn = lambda x: x + use_inter_op_parallelism = True # should not have any significance + print_label = " (short circuit mode)" + benchmark_label = "_short_circuit" + with ops.Graph().as_default(): dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) for _ in range(chain_length): dataset = dataset_ops.MapDataset( dataset, - lambda x: x, + map_fn, use_inter_op_parallelism=use_inter_op_parallelism) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() @@ -788,25 +866,39 @@ class MapDatasetBenchmark(test.Benchmark): median_wall_time = np.median(deltas) / 100 print("Map dataset chain length%s: %d Median wall time: %f" % - (" (single threaded mode)" if not use_inter_op_parallelism - else "", chain_length, median_wall_time)) + (print_label, chain_length, median_wall_time)) self.report_benchmark( iters=1000, wall_time=median_wall_time, name="benchmark_map_dataset_chain_latency_%d%s" % - (chain_length, "_single_threaded" - if not use_inter_op_parallelism else "")) + (chain_length, benchmark_label)) def benchmarkMapFanOut(self): fan_outs = [1, 2, 5, 10, 20, 50, 100] for fan_out in fan_outs: - for use_inter_op_parallelism in [False, True]: + for mode in ["general", "single-threaded", "short-circuit"]: + if mode == "general": + map_fn = lambda *xs: [x + 1 for x in xs] + use_inter_op_parallelism = True + print_label = "" + benchmark_label = "" + if mode == "single-threaded": + map_fn = lambda *xs: [x + 1 for x in xs] + use_inter_op_parallelism = False + print_label = " (single threaded mode)" + benchmark_label = "_single_threaded" + if mode == "short-circuit": + map_fn = lambda *xs: xs + use_inter_op_parallelism = True # should not have any significance + print_label = " (short circuit mode)" + benchmark_label = "_short_circuit" + with ops.Graph().as_default(): dataset = dataset_ops.Dataset.from_tensors( tuple(0 for _ in range(fan_out))).repeat(None) dataset = dataset_ops.MapDataset( dataset, - lambda *xs: xs, + map_fn, use_inter_op_parallelism=use_inter_op_parallelism) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() @@ -824,14 +916,12 @@ class MapDatasetBenchmark(test.Benchmark): median_wall_time = np.median(deltas) / 100 print("Map dataset fan out%s: %d Median wall time: %f" % - (" (single threaded mode)" if not use_inter_op_parallelism - else "", fan_out, median_wall_time)) + (print_label, fan_out, median_wall_time)) self.report_benchmark( iters=1000, wall_time=median_wall_time, - name="benchmark_map_dataset_fan_out_%d%s" % - (fan_out, "_single_threaded" - if not use_inter_op_parallelism else "")) + name="benchmark_map_dataset_fan_out_%d%s" % (fan_out, + benchmark_label)) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py index 056664b83b..1cf6dd1bea 100644 --- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py +++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import multi_device_iterator_ops from tensorflow.python.framework import dtypes @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MultiDeviceIteratorTest(test.TestCase): +class MultiDeviceIteratorTest(test_base.DatasetTestBase): def testNoGetNext(self): dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py index 706a65fe55..604e3ad88e 100644 --- a/tensorflow/python/data/kernel_tests/optional_ops_test.py +++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import optional_ops @@ -35,7 +36,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class OptionalTest(test.TestCase, parameterized.TestCase): +class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): @test_util.run_in_graph_and_eager_modes def testFromValue(self): diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py index cc97bac609..76e2697b29 100644 --- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class PrefetchDatasetTest(test.TestCase, parameterized.TestCase): +class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.parameters((-1), (0), (5)) def testBufferSize(self, buffer_size): diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py index 51e90785e7..b7e2a5f615 100644 --- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes @@ -34,7 +35,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test -class RangeDatasetTest(test.TestCase): +class RangeDatasetTest(test_base.DatasetTestBase): def tearDown(self): # Remove all checkpoint files. diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py index aa3636364d..aef2dd1d9c 100644 --- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,7 @@ import gzip import os import zlib +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers @@ -46,7 +47,7 @@ except ImportError: psutil_import_succeeded = False -class TextLineDatasetTest(test.TestCase): +class TextLineDatasetTest(test_base.DatasetTestBase): def _lineText(self, f, l): return compat.as_bytes("%d: %d" % (f, l)) @@ -199,7 +200,7 @@ class TextLineDatasetTest(test.TestCase): self.assertNotIn(filename, [open_file.path for open_file in open_files]) -class FixedLengthRecordReaderTest(test.TestCase): +class FixedLengthRecordReaderTest(test_base.DatasetTestBase): def setUp(self): super(FixedLengthRecordReaderTest, self).setUp() @@ -621,7 +622,7 @@ class FixedLengthRecordReaderTest(test.TestCase): sess.run(get_next_op) -class TFRecordDatasetTest(test.TestCase): +class TFRecordDatasetTest(test_base.DatasetTestBase): def setUp(self): super(TFRecordDatasetTest, self).setUp() diff --git a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py new file mode 100644 index 0000000000..11e07300b9 --- /dev/null +++ b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py @@ -0,0 +1,124 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + + def testSum(self): + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1) + result = ds.reduce(np.int64(0), lambda x, y: x + y) + with self.cached_session() as sess: + self.assertEqual(((i + 1) * i) // 2, sess.run(result)) + + def testSumTuple(self): + + def reduce_fn(state, value): + v1, v2 = value + return state + v1 + v2 + + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1) + ds = dataset_ops.Dataset.zip((ds, ds)) + result = ds.reduce(np.int64(0), reduce_fn) + with self.cached_session() as sess: + self.assertEqual(((i + 1) * i), sess.run(result)) + + def testSumAndCount(self): + + def reduce_fn(state, value): + s, c = state + return s + value, c + 1 + + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1) + result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn) + with self.cached_session() as sess: + s, c = sess.run(result) + self.assertEqual(((i + 1) * i) // 2, s) + self.assertEqual(i, c) + + def testSquareUsingPlaceholder(self): + delta = array_ops.placeholder(dtype=dtypes.int64) + + def reduce_fn(state, _): + return state + delta + + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1) + result = ds.reduce(np.int64(0), reduce_fn) + with self.cached_session() as sess: + square = sess.run(result, feed_dict={delta: i}) + self.assertEqual(i * i, square) + + def testSparse(self): + + def reduce_fn(_, value): + return value + + def make_sparse_fn(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + for i in range(10): + ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1)) + result = ds.reduce(make_sparse_fn(0), reduce_fn) + with self.cached_session() as sess: + self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result)) + + def testNested(self): + + def reduce_fn(state, value): + state["dense"] += value["dense"] + state["sparse"] = value["sparse"] + return state + + def make_sparse_fn(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def map_fn(i): + return {"dense": math_ops.cast(i, dtype=dtypes.int64), + "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))} + + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn) + result = ds.reduce(map_fn(0), reduce_fn) + with self.cached_session() as sess: + result = sess.run(result) + self.assertEqual(((i + 1) * i) // 2, result["dense"]) + self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py index 37e2333560..e86356dee7 100644 --- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SequenceDatasetTest(test.TestCase): +class SequenceDatasetTest(test_base.DatasetTestBase): def testRepeatTensorDataset(self): """Test a dataset that repeats its input multiple times.""" diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py index 137f6341ce..b9f3c79da5 100644 --- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py @@ -17,12 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.platform import test -class ShardDatasetOpTest(test.TestCase): +class ShardDatasetOpTest(test_base.DatasetTestBase): def testSimpleCase(self): dataset = dataset_ops.Dataset.range(10).shard(5, 2) diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py index f294840706..347af18576 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py @@ -21,6 +21,7 @@ import collections import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ShuffleDatasetTest(test.TestCase): +class ShuffleDatasetTest(test_base.DatasetTestBase): def testShuffleDataset(self): components = ( diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py new file mode 100644 index 0000000000..b73a94e683 --- /dev/null +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -0,0 +1,138 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utilities for tf.data functionality.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class DatasetTestBase(test.TestCase): + """Base class for dataset tests.""" + + def assertSparseValuesEqual(self, a, b): + """Asserts that two SparseTensors/SparseTensorValues are equal.""" + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def getNext(self, dataset): + """Returns a callable that returns the next element of the dataset. + + Example use: + ```python + # In both graph and eager modes + dataset = ... + nxt = self.getNext(dataset) + result = self.evaluate(nxt()) + ``` + + Args: + dataset: A dataset whose next element is returned + + Returns: + A callable that returns the next element of `dataset` + """ + it = dataset.make_one_shot_iterator() + if context.executing_eagerly(): + return it.get_next + else: + nxt = it.get_next() + return lambda: nxt + + def assertDatasetsEqual(self, dataset1, dataset2): + """Checks that datasets are equal. Supports both graph and eager mode.""" + self.assertEqual(dataset1.output_types, dataset2.output_types) + self.assertEqual(dataset1.output_classes, dataset2.output_classes) + + next1 = self.getNext(dataset1) + next2 = self.getNext(dataset2) + while True: + try: + op1 = self.evaluate(next1()) + except errors.OutOfRangeError: + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(next2()) + break + op2 = self.evaluate(next2()) + + op1 = nest.flatten(op1) + op2 = nest.flatten(op2) + assert len(op1) == len(op2) + for i in range(len(op1)): + if isinstance( + op1[i], + (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): + self.assertSparseValuesEqual(op1[i], op2[i]) + else: + self.assertAllEqual(op1[i], op2[i]) + + def assertDatasetsRaiseSameError(self, + dataset1, + dataset2, + exception_class, + replacements=None): + """Checks that datasets raise the same error on the first get_next call.""" + next1 = self.getNext(dataset1) + next2 = self.getNext(dataset2) + try: + self.evaluate(next1()) + raise ValueError( + 'Expected dataset to raise an error of type %s, but it did not.' % + repr(exception_class)) + except exception_class as e: + expected_message = e.message + for old, new, count in replacements: + expected_message = expected_message.replace(old, new, count) + # Check that the first segment of the error messages are the same. + with self.assertRaisesRegexp(exception_class, + re.escape(expected_message)): + self.evaluate(next2()) + + def structuredDataset(self, structure, shape=None, dtype=dtypes.int64): + """Returns a singleton dataset with the given structure.""" + if shape is None: + shape = [] + if structure is None: + return dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self.structuredDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def structuredElement(self, structure, shape=None, dtype=dtypes.int64): + """Returns an element with the given structure.""" + if shape is None: + shape = [] + if structure is None: + return array_ops.zeros(shape, dtype=dtype) + else: + return tuple([ + self.structuredElement(substructure, shape, dtype) + for substructure in structure + ]) diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py index fd4348426d..9d06781094 100644 --- a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class WindowDatasetTest(test.TestCase, parameterized.TestCase): +class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("1", 20, 14, 7, 1), @@ -150,11 +151,6 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): stride_t: stride }) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testWindowSparse(self): def _sparse(i): diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py index 3106effbd3..9d76387a34 100644 --- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ZipDatasetTest(test.TestCase): +class ZipDatasetTest(test_base.DatasetTestBase): def testZipDataset(self): component_placeholders = [ |