about summary refs log tree commit diff homepage
path: root/tensorflow/python/data/kernel_tests
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests')
-rw-r--r-- tensorflow/python/data/kernel_tests/BUILD 327
-rw-r--r-- tensorflow/python/data/kernel_tests/batch_dataset_op_test.py 10
-rw-r--r-- tensorflow/python/data/kernel_tests/cache_dataset_op_test.py 5
-rw-r--r-- tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py 8
-rw-r--r-- tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/dataset_ops_test.py 159
-rw-r--r-- tensorflow/python/data/kernel_tests/filter_dataset_op_test.py 10
-rw-r--r-- tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/inputs_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/map_dataset_op_test.py 134
-rw-r--r-- tensorflow/python/data/kernel_tests/multi_device_iterator_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/optional_ops_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/range_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py 7
-rw-r--r-- tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py 124
-rw-r--r-- tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/shard_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py 3
-rw-r--r-- tensorflow/python/data/kernel_tests/test_base.py 138
-rw-r--r-- tensorflow/python/data/kernel_tests/window_dataset_op_test.py 8
-rw-r--r-- tensorflow/python/data/kernel_tests/zip_dataset_op_test.py 3
25 files changed, 760 insertions, 212 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 7a6f03d4d3..c7295d6e69 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -15,6 +15,7 @@ tf_py_test(
size = "small",
srcs = ["batch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -31,10 +32,44 @@ tf_py_test(
)
tf_py_test(
+ name = "cache_dataset_op_test",
+ size = "small",
+ srcs = ["cache_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+tf_py_test(
+ name = "concatenate_dataset_op_test",
+ size = "small",
+ srcs = ["concatenate_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_py_test(
name = "dataset_constructor_op_test",
size = "small",
srcs = ["dataset_constructor_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -63,6 +98,7 @@ tf_py_test(
size = "medium",
srcs = ["dataset_from_generator_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -78,8 +114,11 @@ tf_py_test(
size = "small",
srcs = ["dataset_ops_test.py"],
additional_deps = [
- "//tensorflow/core:protos_all_py",
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -89,6 +128,7 @@ tf_py_test(
size = "small",
srcs = ["filter_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -106,6 +146,7 @@ tf_py_test(
size = "small",
srcs = ["flat_map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -123,6 +164,7 @@ tf_py_test(
size = "small",
srcs = ["list_files_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -137,6 +179,7 @@ tf_py_test(
size = "small",
srcs = ["interleave_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -151,11 +194,80 @@ tf_py_test(
],
)
+cuda_py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:util",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:training",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ ],
+ grpc_enabled = True,
+)
+
+tf_py_test(
+ name = "iterator_ops_cluster_test",
+ size = "small",
+ srcs = ["iterator_ops_cluster_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:lookup_ops",
+ ],
+ grpc_enabled = True,
+ tags = [
+ "no_oss", # Test flaky due to port collisions.
+ "no_windows",
+ ],
+)
+
tf_py_test(
name = "map_dataset_op_test",
size = "small",
srcs = ["map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -177,11 +289,54 @@ tf_py_test(
],
)
+cuda_py_test(
+ name = "multi_device_iterator_test",
+ size = "medium",
+ srcs = ["multi_device_iterator_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_windows_gpu",
+ ],
+)
+
+cuda_py_test(
+ name = "optional_ops_test",
+ size = "small",
+ srcs = ["optional_ops_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
tf_py_test(
name = "prefetch_dataset_op_test",
size = "small",
srcs = ["prefetch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -197,6 +352,7 @@ tf_py_test(
size = "small",
srcs = ["range_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dataset_ops_gen",
@@ -218,6 +374,7 @@ tf_py_test(
size = "small",
srcs = ["reader_dataset_ops_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -236,32 +393,35 @@ tf_py_test(
)
tf_py_test(
- name = "sequence_dataset_op_test",
+ name = "reduce_dataset_op_test",
size = "small",
- srcs = ["sequence_dataset_op_test.py"],
+ srcs = ["reduce_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
- name = "shuffle_dataset_op_test",
+ name = "sequence_dataset_op_test",
size = "small",
- srcs = ["shuffle_dataset_op_test.py"],
+ srcs = ["sequence_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
],
)
@@ -270,6 +430,7 @@ tf_py_test(
size = "small",
srcs = ["shard_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
@@ -277,155 +438,30 @@ tf_py_test(
)
tf_py_test(
- name = "cache_dataset_op_test",
+ name = "shuffle_dataset_op_test",
size = "small",
- srcs = ["cache_dataset_op_test.py"],
+ srcs = ["shuffle_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)
-tf_py_test(
- name = "zip_dataset_op_test",
- size = "small",
- srcs = ["zip_dataset_op_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
- name = "concatenate_dataset_op_test",
- size = "small",
- srcs = ["concatenate_dataset_op_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-cuda_py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/training/checkpointable:util",
- "//tensorflow/python:array_ops",
+py_library(
+ name = "test_base",
+ srcs = ["test_base.py"],
+ deps = [
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:training",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
- ],
- grpc_enabled = True,
-)
-
-tf_py_test(
- name = "iterator_ops_cluster_test",
- size = "small",
- srcs = ["iterator_ops_cluster_test.py"],
- additional_deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:session",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:lookup_ops",
- ],
- grpc_enabled = True,
- tags = [
- "no_oss", # Test flaky due to port collisions.
- "no_windows",
- ],
-)
-
-cuda_py_test(
- name = "optional_ops_test",
- size = "small",
- srcs = ["optional_ops_test.py"],
- additional_deps = [
- "@absl_py//absl/testing:parameterized",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:optional_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:tensor_shape",
- ],
-)
-
-cuda_py_test(
- name = "multi_device_iterator_test",
- size = "small",
- srcs = ["multi_device_iterator_test.py"],
- additional_deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:multi_device_iterator_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- ],
- tags = [
- "no_windows_gpu",
+ "//tensorflow/python/data/util:nest",
],
)
@@ -434,6 +470,7 @@ tf_py_test(
size = "small",
srcs = ["window_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -447,14 +484,16 @@ tf_py_test(
)
tf_py_test(
- name = "inputs_test",
+ name = "zip_dataset_op_test",
size = "small",
- srcs = ["inputs_test.py"],
+ srcs = ["zip_dataset_op_test.py"],
additional_deps = [
- "@absl_py//absl/testing:parameterized",
+ ":test_base",
"//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
],
)
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index c48708a2b9..9cb4daf284 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -37,7 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('even', 28, 14, False),
@@ -115,11 +116,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testBatchSparse(self):
def _sparse(i):
@@ -227,7 +223,7 @@ def _random_seq_lens(count):
return np.random.randint(20, size=(count,)).astype(np.int32)
-class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
+class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('default_padding', _random_seq_lens(32), 4, [-1], False),
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index d5f5b2fe05..63625fac03 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -23,6 +23,7 @@ import tempfile
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -34,7 +35,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class FileCacheDatasetTest(test.TestCase):
+class FileCacheDatasetTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
@@ -200,7 +201,7 @@ class FileCacheDatasetTest(test.TestCase):
self.assertAllEqual(elements, elements_itr2)
-class MemoryCacheDatasetTest(test.TestCase):
+class MemoryCacheDatasetTest(test_base.DatasetTestBase):
def testCacheDatasetPassthrough(self):
with ops.device("cpu:0"):
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index 5dfb84f28e..83af31f380 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class ConcatenateDatasetTest(test.TestCase):
+class ConcatenateDatasetTest(test_base.DatasetTestBase):
def testConcatenateDataset(self):
input_components = (
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
index e43564a2eb..bc6b36285a 100644
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -36,7 +37,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testFromTensors(self):
"""Test a dataset that represents a single tuple of tensors."""
@@ -58,11 +59,6 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testFromTensorsSparse(self):
"""Test a dataset that represents a single tuple of tensors."""
components = (sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index cd0c1ddf1e..cb8cb9a77d 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -22,6 +22,7 @@ import threading
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
output_types=None):
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index 239aa85175..b9f8875b9f 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -18,12 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
-class DatasetOpsTest(test.TestCase):
+class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
@@ -32,6 +40,155 @@ class DatasetOpsTest(test.TestCase):
sess.run(dataset._as_serialized_graph()))
self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
+ @staticmethod
+ def make_apply_fn(dataset):
+
+ def apply_fn(dataset):
+
+ def _apply_fn(dataset):
+ return dataset.cache()
+
+ return dataset.apply(_apply_fn)
+
+ return apply_fn
+
+ @staticmethod
+ def make_gen():
+
+ def gen():
+ yield 42
+
+ return gen
+
+ @staticmethod
+ def make_interleave_fn(dataset, num_parallel_calls=None):
+
+ def interleave_fn(dataset):
+ return dataset.interleave(
+ lambda x: dataset_ops.Dataset.range(0),
+ cycle_length=2,
+ num_parallel_calls=num_parallel_calls)
+
+ return interleave_fn
+
+ @parameterized.named_parameters(
+ ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
+ ("FromGenerator",
+ dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
+ 1),
+ ("FromSparseTensorSlices",
+ dataset_ops.Dataset.from_sparse_tensor_slices(
+ sparse_tensor.SparseTensor(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])))),
+ ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
+ ("FromTensorSlices", dataset_ops.Dataset.from_tensor_slices([42])),
+ ("Range", dataset_ops.Dataset.range(10)),
+ ("TextLine", readers.TextLineDataset("")),
+ ("TFRecord", readers.TFRecordDataset(""), 1),
+ )
+ def testDatasetSourceInputs(self, dataset, num_inputs=0):
+ self.assertEqual(num_inputs, len(dataset._inputs()))
+
+ @parameterized.named_parameters(
+ ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
+ ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
+ ("Filter", lambda x: x.filter(lambda x: True),
+ dataset_ops.Dataset.range(0)),
+ ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
+ ("PaddedBatch", lambda x: x.padded_batch(10, []),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelInterleave",
+ make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
+ dataset_ops.Dataset.range(0)),
+ ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
+ ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
+ ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
+ ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
+ ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
+ )
+ def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
+ self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
+
+ @parameterized.named_parameters(
+ ("Concatenate", lambda x, y: x.concatenate(y),
+ dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
+ def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
+ self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
+
+ @parameterized.named_parameters(
+ ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
+ ("ZipNest", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0),
+ (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
+ ("ZipTuple", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
+ def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
+ self.assertEqual(
+ nest.flatten(input_datasets),
+ dataset_fn(input_datasets)._inputs())
+
+ def testCollectInputs(self):
+ ds1 = dataset_ops.Dataset.range(0)
+ ds2 = ds1.concatenate(ds1)
+ ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
+
+ inputs = []
+ queue = [ds3]
+ while queue:
+ ds = queue[0]
+ queue = queue[1:]
+ queue.extend(ds._inputs())
+ inputs.append(ds)
+
+ self.assertEqual(5, inputs.count(ds1))
+ self.assertEqual(2, inputs.count(ds2))
+ self.assertEqual(1, inputs.count(ds3))
+
+ def testOptionsDefault(self):
+ ds = dataset_ops.Dataset.range(0)
+ self.assertEqual(dataset_ops.Options(), ds.options())
+
+ def testOptionsOnce(self):
+ options = dataset_ops.Options()
+ ds = dataset_ops.Dataset.range(0).with_options(options).cache()
+ self.assertEqual(options, ds.options())
+
+ def testOptionsTwiceSame(self):
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
+ options)
+ self.assertEqual(options, ds.options())
+
+ def testOptionsTwiceDifferent(self):
+ options1 = dataset_ops.Options()
+ options1.experimental_autotune = True
+ options2 = dataset_ops.Options()
+ options2.experimental_filter_fusion = False
+ ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
+ options2)
+ self.assertTrue(ds.options().experimental_autotune)
+ self.assertFalse(ds.options().experimental_filter_fusion)
+
+ def testOptionsTwiceDifferentError(self):
+ options1 = dataset_ops.Options()
+ options1.experimental_autotune = True
+ options2 = dataset_ops.Options()
+ options2.experimental_autotune = False
+ with self.assertRaisesRegexp(ValueError,
+ "Cannot merge incompatible values of option"):
+ dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 19944d389f..a0c6b37a6d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -22,6 +22,7 @@ import time
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class FilterDatasetTest(test.TestCase):
+class FilterDatasetTest(test_base.DatasetTestBase):
def testFilterDataset(self):
components = (
@@ -129,11 +130,6 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _map_fn(i):
@@ -160,7 +156,7 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testReturnComponent(self):
+ def testShortCircuit(self):
iterator = (
dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(10),
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
index 1123cbff62..68038f9cfc 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
-class FlatMapDatasetTest(test.TestCase):
+class FlatMapDatasetTest(test_base.DatasetTestBase):
# pylint: disable=g-long-lambda
def testFlatMapDataset(self):
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
index 4c9279dd95..d089b49bcc 100644
--- a/tensorflow/python/data/kernel_tests/inputs_test.py
+++ b/tensorflow/python/data/kernel_tests/inputs_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
@@ -27,7 +28,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
-class InputsTest(test.TestCase, parameterized.TestCase):
+class InputsTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def make_apply_fn(dataset):
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index e7e51df65e..92bb67b6ff 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
+class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index c4b338a58f..8eb13815d4 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -22,6 +22,7 @@ from os import path
import shutil
import tempfile
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class ListFilesDatasetOpTest(test.TestCase):
+class ListFilesDatasetOpTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index ae04995436..4683b1db91 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.Dataset.map()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -47,7 +48,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-class MapDatasetTest(test.TestCase, parameterized.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
@@ -266,6 +267,35 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testCaptureIterator(self):
+
+ def _build_ds(iterator):
+
+ def _map_fn(x):
+ get_next = iterator.get_next()
+ return x * get_next
+
+ return dataset_ops.Dataset.range(10).map(_map_fn)
+
+ def _build_graph():
+ captured_iterator = dataset_ops.Dataset.range(
+ 10).make_initializable_iterator()
+ ds = _build_ds(captured_iterator)
+ iterator = ds.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return captured_iterator.initializer, init_op, get_next
+
+ with ops.Graph().as_default() as g:
+ captured_init_op, init_op, get_next = _build_graph()
+ with self.session(graph=g) as sess:
+ sess.run(captured_init_op)
+ sess.run(init_op)
+ for i in range(10):
+ self.assertEqual(i * i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testCaptureHashTable(self):
# NOTE(mrry): We must use the V2 variants of `HashTable`
# etc. because these produce a `tf.resource`-typed output that is
@@ -574,11 +604,6 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _sparse(i):
@@ -597,7 +622,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _sparse(i))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -624,7 +649,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -758,19 +783,72 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
+ @parameterized.named_parameters(
+ ("SequentialIdentity", None, lambda x: x, None),
+ ("SequentialReplicate", None, lambda x: (x, x), None),
+ ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
+ ("SequentialProject", (None, None), lambda x, y: x, None),
+ ("ParallelIdentity", None, lambda x: x, 10),
+ ("ParallelReplicate", None, lambda x: (x, x), 10),
+ ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
+ ("ParallelProject", (None, None), lambda x, y: x, 10),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().map(
+ map_fn, num_parallel_calls=num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(*sess.run(self.structuredElement(structure)))
+ else:
+ expected = map_fn(sess.run(self.structuredElement(structure)))
+ self.assertEqual(expected, sess.run(get_next))
+
+ @parameterized.named_parameters(
+ ("Sequential", None),
+ ("Parallel", 10),
+ )
+ def testShortCircuitCapturedInput(self, num_parallel_calls):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().map(
+ lambda x: captured_t, num_parallel_calls=num_parallel_calls)
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertEqual(42, sess.run(get_next))
+
class MapDatasetBenchmark(test.Benchmark):
def benchmarkChainOfMaps(self):
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda x: x
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
for _ in range(chain_length):
dataset = dataset_ops.MapDataset(
dataset,
- lambda x: x,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -788,25 +866,39 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset chain length%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", chain_length, median_wall_time))
+ (print_label, chain_length, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
name="benchmark_map_dataset_chain_latency_%d%s" %
- (chain_length, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ (chain_length, benchmark_label))
def benchmarkMapFanOut(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
for fan_out in fan_outs:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda *xs: xs
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(
tuple(0 for _ in range(fan_out))).repeat(None)
dataset = dataset_ops.MapDataset(
dataset,
- lambda *xs: xs,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -824,14 +916,12 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset fan out%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", fan_out, median_wall_time))
+ (print_label, fan_out, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d%s" %
- (fan_out, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ name="benchmark_map_dataset_fan_out_%d%s" % (fan_out,
+ benchmark_label))
if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
index 056664b83b..1cf6dd1bea 100644
--- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import dtypes
@@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MultiDeviceIteratorTest(test.TestCase):
+class MultiDeviceIteratorTest(test_base.DatasetTestBase):
def testNoGetNext(self):
dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index 706a65fe55..604e3ad88e 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
@@ -35,7 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class OptionalTest(test.TestCase, parameterized.TestCase):
+class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFromValue(self):
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index cc97bac609..76e2697b29 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
+class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((-1), (0), (5))
def testBufferSize(self, buffer_size):
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
index 51e90785e7..b7e2a5f615 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
@@ -34,7 +35,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def tearDown(self):
# Remove all checkpoint files.
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index aa3636364d..aef2dd1d9c 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -21,6 +21,7 @@ import gzip
import os
import zlib
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
@@ -46,7 +47,7 @@ except ImportError:
psutil_import_succeeded = False
-class TextLineDatasetTest(test.TestCase):
+class TextLineDatasetTest(test_base.DatasetTestBase):
def _lineText(self, f, l):
return compat.as_bytes("%d: %d" % (f, l))
@@ -199,7 +200,7 @@ class TextLineDatasetTest(test.TestCase):
self.assertNotIn(filename, [open_file.path for open_file in open_files])
-class FixedLengthRecordReaderTest(test.TestCase):
+class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
def setUp(self):
super(FixedLengthRecordReaderTest, self).setUp()
@@ -621,7 +622,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
sess.run(get_next_op)
-class TFRecordDatasetTest(test.TestCase):
+class TFRecordDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordDatasetTest, self).setUp()
diff --git a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
new file mode 100644
index 0000000000..11e07300b9
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
@@ -0,0 +1,124 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.reduce()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testSum(self):
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), lambda x, y: x + y)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i) // 2, sess.run(result))
+
+ def testSumTuple(self):
+
+ def reduce_fn(state, value):
+ v1, v2 = value
+ return state + v1 + v2
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ ds = dataset_ops.Dataset.zip((ds, ds))
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i), sess.run(result))
+
+ def testSumAndCount(self):
+
+ def reduce_fn(state, value):
+ s, c = state
+ return s + value, c + 1
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
+ with self.cached_session() as sess:
+ s, c = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, s)
+ self.assertEqual(i, c)
+
+ def testSquareUsingPlaceholder(self):
+ delta = array_ops.placeholder(dtype=dtypes.int64)
+
+ def reduce_fn(state, _):
+ return state + delta
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ square = sess.run(result, feed_dict={delta: i})
+ self.assertEqual(i * i, square)
+
+ def testSparse(self):
+
+ def reduce_fn(_, value):
+ return value
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
+ result = ds.reduce(make_sparse_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result))
+
+ def testNested(self):
+
+ def reduce_fn(state, value):
+ state["dense"] += value["dense"]
+ state["sparse"] = value["sparse"]
+ return state
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def map_fn(i):
+ return {"dense": math_ops.cast(i, dtype=dtypes.int64),
+ "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))}
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
+ result = ds.reduce(map_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ result = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, result["dense"])
+ self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 37e2333560..e86356dee7 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SequenceDatasetTest(test.TestCase):
+class SequenceDatasetTest(test_base.DatasetTestBase):
def testRepeatTensorDataset(self):
"""Test a dataset that repeats its input multiple times."""
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index 137f6341ce..b9f3c79da5 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class ShardDatasetOpTest(test.TestCase):
+class ShardDatasetOpTest(test_base.DatasetTestBase):
def testSimpleCase(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index f294840706..347af18576 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -21,6 +21,7 @@ import collections
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ShuffleDatasetTest(test.TestCase):
+class ShuffleDatasetTest(test_base.DatasetTestBase):
def testShuffleDataset(self):
components = (
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
new file mode 100644
index 0000000000..b73a94e683
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -0,0 +1,138 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class DatasetTestBase(test.TestCase):
+ """Base class for dataset tests."""
+
+ def assertSparseValuesEqual(self, a, b):
+ """Asserts that two SparseTensors/SparseTensorValues are equal."""
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def getNext(self, dataset):
+ """Returns a callable that returns the next element of the dataset.
+
+ Example use:
+ ```python
+ # In both graph and eager modes
+ dataset = ...
+ nxt = self.getNext(dataset)
+ result = self.evaluate(nxt())
+ ```
+
+ Args:
+ dataset: A dataset whose next element is returned
+
+ Returns:
+ A callable that returns the next element of `dataset`
+ """
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ nxt = it.get_next()
+ return lambda: nxt
+
+ def assertDatasetsEqual(self, dataset1, dataset2):
+ """Checks that datasets are equal. Supports both graph and eager mode."""
+ self.assertEqual(dataset1.output_types, dataset2.output_types)
+ self.assertEqual(dataset1.output_classes, dataset2.output_classes)
+
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except errors.OutOfRangeError:
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+
+ op1 = nest.flatten(op1)
+ op2 = nest.flatten(op2)
+ assert len(op1) == len(op2)
+ for i in range(len(op1)):
+ if isinstance(
+ op1[i],
+ (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ self.assertSparseValuesEqual(op1[i], op2[i])
+ else:
+ self.assertAllEqual(op1[i], op2[i])
+
+ def assertDatasetsRaiseSameError(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
+ """Checks that datasets raise the same error on the first get_next call."""
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ try:
+ self.evaluate(next1())
+ raise ValueError(
+ 'Expected dataset to raise an error of type %s, but it did not.' %
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements:
+ expected_message = expected_message.replace(old, new, count)
+ # Check that the second dataset's error message matches the first's (after replacements).
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
+ self.evaluate(next2())
+
+ def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns a singleton dataset with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return dataset_ops.Dataset.from_tensors(
+ array_ops.zeros(shape, dtype=dtype))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self.structuredDataset(substructure, shape, dtype)
+ for substructure in structure
+ ]))
+
+ def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns an element with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return array_ops.zeros(shape, dtype=dtype)
+ else:
+ return tuple([
+ self.structuredElement(substructure, shape, dtype)
+ for substructure in structure
+ ])
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
index fd4348426d..9d06781094 100644
--- a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -150,11 +151,6 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
stride_t: stride
})
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testWindowSparse(self):
def _sparse(i):
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
index 3106effbd3..9d76387a34 100644
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ZipDatasetTest(test.TestCase):
+class ZipDatasetTest(test_base.DatasetTestBase):
def testZipDataset(self):
component_placeholders = [