author    Derek Murray <mrry@google.com>  2018-10-04 12:41:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-04 12:45:57 -0700
commit    2c75da86ffdb9d04b2b94ce89891f17a8656da22 (patch)
tree      c22cdd5d9c3e324dad52d3fe5616fe52aac0fdca /tensorflow/python/data
parent    900d115135656229e3667025f925eb92687dce18 (diff)
[tf.data] Clean up tests for `tf.data.experimental`.
This change splits up large test files into smaller ones, and re-enables tests that were disabled for obsolete reasons.
PiperOrigin-RevId: 215785396
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--  tensorflow/python/data/experimental/benchmarks/BUILD | 25
-rw-r--r--  tensorflow/python/data/experimental/benchmarks/map_benchmark.py (renamed from tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py) | 114
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/BUILD | 545
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py | 686
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py | 322
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/bucketing_test.py | 824
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py) | 417
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/counter_test.py | 51
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py) | 4
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py | 692
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py | 124
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py) | 26
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py | 247
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py | 199
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py | 367
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py | 115
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py | 239
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py) | 425
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py | 243
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py | 337
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py) | 6
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py) | 4
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py) | 4
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py | 234
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py | 4
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/resample_test.py) | 4
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py) | 4
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/scan_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py) | 4
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/BUILD | 22
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py) | 0
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py | 2
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py | 4
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py | 85
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py) | 2
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py) | 6
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py (renamed from tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py) | 3
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py | 2
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py) | 2
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/unbatch_test.py | 300
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/unique_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py) | 4
-rw-r--r--  tensorflow/python/data/kernel_tests/map_dataset_op_test.py | 31
41 files changed, 3172 insertions, 3557 deletions
diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD
new file mode 100644
index 0000000000..b9398aebe7
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/BUILD
@@ -0,0 +1,25 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "map_benchmark",
+ size = "medium",
+ srcs = ["map_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/experimental/benchmarks/map_benchmark.py
index 2f0bd1456b..ad253cffa5 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/benchmarks/map_benchmark.py
@@ -19,7 +19,6 @@ from __future__ import print_function
import hashlib
import itertools
-import os
import time
import numpy as np
@@ -27,128 +26,15 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
-class MapDatasetTest(test_base.DatasetTestBase):
-
- def testMapIgnoreError(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.check_numerics(x, "message")).apply(
- error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for x in [1., 2., 3., 5.]:
- self.assertEqual(x, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testParallelMapIgnoreError(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.check_numerics(x, "message"),
- num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for x in [1., 2., 3., 5.]:
- self.assertEqual(x, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testReadFileIgnoreError(self):
-
- def write_string_to_file(value, filename):
- with open(filename, "w") as f:
- f.write(value)
-
- filenames = [
- os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
- ]
- for filename in filenames:
- write_string_to_file(filename, filename)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(filenames).map(
- io_ops.read_file,
- num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # All of the files are present.
- sess.run(init_op)
- for filename in filenames:
- self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Delete one of the files.
- os.remove(filenames[0])
-
- # Attempting to read filenames[0] will fail, but ignore_errors()
- # will catch the error.
- sess.run(init_op)
- for filename in filenames[1:]:
- self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testCaptureResourceInMapFn(self):
-
- def _build_ds(iterator):
-
- def _map_fn(x):
- get_next = iterator.get_next()
- return x * get_next
-
- return dataset_ops.Dataset.range(10).map(_map_fn)
-
- def _build_graph():
- captured_iterator = dataset_ops.Dataset.range(
- 10).make_initializable_iterator()
- ds = _build_ds(captured_iterator)
- iterator = ds.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- return captured_iterator.initializer, init_op, get_next
-
- with ops.Graph().as_default() as g:
- captured_init_op, init_op, get_next = _build_graph()
- with self.session(graph=g) as sess:
- sess.run(captured_init_op)
- sess.run(init_op)
- for i in range(10):
- self.assertEquals(i * i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
class MapDatasetBenchmark(test.Benchmark):
# The purpose of this benchmark is to compare the performance of chaining vs
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index f56127f3ef..4eef9580ad 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -8,75 +8,62 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
- name = "batch_dataset_op_test",
+ name = "bucket_by_sequence_length_test",
size = "medium",
- srcs = ["batch_dataset_op_test.py"],
+ srcs = ["bucket_by_sequence_length_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss", # (b/79552534)
- "no_pip",
- "no_windows",
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
],
)
+cuda_py_test(
+ name = "copy_to_device_test",
+ size = "small",
+ srcs = ["copy_to_device_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
py_test(
- name = "bucketing_test",
- size = "medium",
- srcs = ["bucketing_test.py"],
+ name = "counter_test",
+ size = "small",
+ srcs = ["counter_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
deps = [
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/experimental/ops:counter",
"//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
py_test(
- name = "csv_dataset_op_test",
+ name = "csv_dataset_test",
size = "medium",
- srcs = ["csv_dataset_op_test.py"],
+ srcs = ["csv_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ tags = ["no_pip"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -97,25 +84,18 @@ py_test(
)
py_test(
- name = "dataset_constructor_op_test",
- size = "medium",
- srcs = ["dataset_constructor_op_test.py"],
+ name = "dense_to_sparse_batch_test",
+ srcs = ["dense_to_sparse_batch_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "manual",
- "no_oss",
- "no_pip",
- "no_windows",
- "nomac", # b/62040583
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
],
)
@@ -124,11 +104,6 @@ py_test(
size = "medium",
srcs = ["directed_interleave_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -141,14 +116,67 @@ py_test(
)
py_test(
+ name = "enumerate_dataset_test",
+ size = "small",
+ srcs = ["enumerate_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_op_test",
+ size = "medium",
+ srcs = ["filter_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "function_buffering_resource_test",
+ size = "small",
+ srcs = ["function_buffering_resource_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+py_test(
name = "get_single_element_test",
size = "small",
srcs = ["get_single_element_test.py"],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -165,19 +193,20 @@ py_test(
)
py_test(
- name = "indexed_dataset_ops_test",
- srcs = ["indexed_dataset_ops_test.py"],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ name = "group_by_reducer_test",
+ size = "medium",
+ srcs = ["group_by_reducer_test.py"],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/experimental/ops:indexed_dataset_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
@@ -185,107 +214,134 @@ py_test(
)
py_test(
- name = "interleave_dataset_op_test",
+ name = "group_by_window_test",
size = "medium",
- srcs = ["interleave_dataset_op_test.py"],
+ srcs = ["group_by_window_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- "notap",
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "@six_archive//:six",
+ "//third_party/py/numpy",
],
)
py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
+ name = "ignore_errors_test",
+ srcs = ["ignore_errors_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
],
+)
+
+py_test(
+ name = "indexed_dataset_ops_test",
+ srcs = ["indexed_dataset_ops_test.py"],
deps = [
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/experimental/ops:indexed_dataset_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator:estimator_py",
+ "//third_party/py/numpy",
],
)
py_test(
- name = "map_dataset_op_test",
+ name = "make_batched_features_dataset_test",
size = "medium",
- srcs = ["map_dataset_op_test.py"],
+ srcs = ["make_batched_features_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- "noasan", # times out
- "optonly",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
],
+)
+
+py_test(
+ name = "make_csv_dataset_test",
+ size = "medium",
+ srcs = ["make_csv_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:batching",
- "//tensorflow/python/data/experimental/ops:error_ops",
- "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
py_test(
- name = "filter_dataset_op_test",
+ name = "make_tf_record_dataset_test",
size = "medium",
- srcs = ["filter_dataset_op_test.py"],
+ srcs = ["make_tf_record_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/util:nest",
],
+)
+
+py_test(
+ name = "map_and_batch_test",
+ size = "medium",
+ srcs = ["map_and_batch_test.py"],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -294,11 +350,7 @@ py_test(
size = "small",
srcs = ["map_defun_op_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ tags = ["no_pip"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
@@ -317,16 +369,57 @@ py_test(
)
py_test(
- name = "parsing_ops_test",
+ name = "override_threadpool_test",
size = "small",
- srcs = ["parsing_ops_test.py"],
+ srcs = ["override_threadpool_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python/data/experimental/ops:threadpool",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "parallel_interleave_test",
+ size = "medium",
+ srcs = ["parallel_interleave_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
- "no_windows",
+ "notap",
],
deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "parse_example_dataset_test",
+ size = "small",
+ srcs = ["parse_example_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -344,53 +437,20 @@ py_test(
)
cuda_py_test(
- name = "prefetching_ops_test",
+ name = "prefetch_to_device_test",
size = "small",
- srcs = ["prefetching_ops_test.py"],
+ srcs = ["prefetch_to_device_test.py"],
additional_deps = [
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- "no_windows_gpu",
- ],
-)
-
-py_test(
- name = "range_dataset_op_test",
- size = "small",
- srcs = ["range_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/experimental/ops:counter",
- "//tensorflow/python/data/experimental/ops:enumerate_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
+ tags = ["no_windows_gpu"],
)
py_library(
@@ -421,41 +481,12 @@ py_library(
)
py_test(
- name = "reader_dataset_ops_test",
- size = "medium",
- srcs = ["reader_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
- deps = [
- ":reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/experimental/ops:readers",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "resample_test",
+ name = "rejection_resample_test",
size = "medium",
- srcs = ["resample_test.py"],
+ srcs = ["rejection_resample_test.py"],
shard_count = 2,
srcs_version = "PY2AND3",
tags = [
- "no_oss",
- "no_pip",
- "no_windows",
"noasan",
"optonly",
],
@@ -477,15 +508,27 @@ py_test(
)
py_test(
- name = "scan_dataset_op_test",
- size = "small",
- srcs = ["scan_dataset_op_test.py"],
+ name = "restructured_dataset_test",
+ size = "medium",
+ srcs = ["restructured_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
],
+)
+
+py_test(
+ name = "scan_test",
+ size = "small",
+ srcs = ["scan_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -503,14 +546,12 @@ py_test(
)
py_test(
- name = "shuffle_dataset_op_test",
+ name = "shuffle_and_repeat_test",
size = "medium",
- srcs = ["shuffle_dataset_op_test.py"],
+ srcs = ["shuffle_and_repeat_test.py"],
srcs_version = "PY2AND3",
tags = [
- "no_oss",
"no_pip",
- "no_windows",
"optonly",
],
deps = [
@@ -525,8 +566,8 @@ py_test(
)
py_library(
- name = "sql_dataset_op_test_base",
- srcs = ["sql_dataset_op_test_base.py"],
+ name = "sql_dataset_test_base",
+ srcs = ["sql_dataset_test_base.py"],
srcs_version = "PY2AND3",
visibility = [
"//tensorflow/python/data/experimental/kernel_tests:__pkg__",
@@ -543,17 +584,13 @@ py_library(
)
py_test(
- name = "sql_dataset_op_test",
+ name = "sql_dataset_test",
size = "small",
- srcs = ["sql_dataset_op_test.py"],
+ srcs = ["sql_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ tags = ["no_pip"],
deps = [
- ":sql_dataset_op_test_base",
+ ":sql_dataset_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@@ -565,11 +602,7 @@ py_test(
size = "medium",
srcs = ["stats_dataset_ops_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ tags = ["no_pip"],
deps = [
":reader_dataset_ops_test_base",
":stats_dataset_test_base",
@@ -595,68 +628,60 @@ py_library(
)
py_test(
- name = "threadpool_dataset_ops_test",
+ name = "tf_record_writer_test",
size = "small",
- srcs = ["threadpool_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ srcs = ["tf_record_writer_test.py"],
deps = [
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:script_ops",
- "//tensorflow/python/data/experimental/ops:threadpool",
- "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:writers",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/ops:readers",
],
)
py_test(
- name = "unique_dataset_op_test",
- size = "small",
- srcs = ["unique_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ name = "unbatch_test",
+ size = "medium",
+ srcs = ["unbatch_test.py"],
deps = [
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
py_test(
- name = "writer_ops_test",
+ name = "unique_test",
size = "small",
- srcs = ["writer_ops_test.py"],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ srcs = ["unique_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
+ "//tensorflow/python:errors",
"//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:writers",
+ "//tensorflow/python/data/experimental/ops:unique",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:readers",
],
)
diff --git a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
deleted file mode 100644
index 956b4518f6..0000000000
--- a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
+++ /dev/null
@@ -1,686 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-import time
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def testDenseToSparseBatchDataset(self):
- components = np.random.randint(12, size=(100,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- for start in range(0, len(components), 4):
- results = sess.run(get_next)
- self.assertAllEqual([[i, j]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)], results.indices)
- self.assertAllEqual(
- [c for c in components[start:start + 4] for _ in range(c)],
- results.values)
- self.assertAllEqual([min(4,
- len(components) - start), 12],
- results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithUnknownShape(self):
- components = np.random.randint(5, size=(40,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x, x], x)).apply(
- batching.dense_to_sparse_batch(
- 4, [5, None])).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- for start in range(0, len(components), 4):
- results = sess.run(get_next)
- self.assertAllEqual([[i, j, z]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)
- for z in range(c)], results.indices)
- self.assertAllEqual([
- c
- for c in components[start:start + 4] for _ in range(c)
- for _ in range(c)
- ], results.values)
- self.assertAllEqual([
- min(4,
- len(components) - start), 5,
- np.max(components[start:start + 4])
- ], results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithInvalidShape(self):
- input_tensor = array_ops.constant([[1]])
- with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
-
- def testDenseToSparseBatchDatasetShapeErrors(self):
- input_tensor = array_ops.placeholder(dtypes.int32)
- iterator = (
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Initialize with an input tensor of incompatible rank.
- sess.run(init_op, feed_dict={input_tensor: [[1]]})
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "incompatible with the row shape"):
- sess.run(get_next)
-
- # Initialize with an input tensor that is larger than `row_shape`.
- sess.run(init_op, feed_dict={input_tensor: range(13)})
- with self.assertRaisesRegexp(errors.DataLossError,
- "larger than the row shape"):
- sess.run(get_next)
-
- def testUnbatchWithUnknownRankInput(self):
- placeholder = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
- batching.unbatch())
- iterator = dataset.make_initializable_iterator()
- next_elem = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
- for i in range(4):
- self.assertEqual(i, sess.run(next_elem))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_elem)
-
- def testUnbatchScalarDataset(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = (dtypes.int32,) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i,) * 3, sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithStrings(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
- expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors(st)
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- st_row = sess.run(next_element)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchDatasetWithDenseAndSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- dense_elem, st_row = sess.run(next_element)
- self.assertEqual(i, dense_elem)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchSingleElementTupleDataset(self):
- data = tuple([(math_ops.range(10),) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32,),) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i,),) * 3, sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchMultiElementTupleDataset(self):
- data = tuple([(math_ops.range(10 * i, 10 * i + 10),
- array_ops.fill([10], "hi")) for i in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32, dtypes.string),) * 3
- data = data.batch(2)
- self.assertAllEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertAllEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
- sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchEmpty(self):
- data = dataset_ops.Dataset.from_tensors(
- (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
- constant_op.constant([], shape=[0, 4, 0])))
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchStaticShapeMismatch(self):
- data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
- np.arange(9)))
- with self.assertRaises(ValueError):
- data.apply(batching.unbatch())
-
- def testUnbatchDynamicShapeMismatch(self):
- ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
- ph2 = array_ops.placeholder(dtypes.int32, shape=None)
- data = dataset_ops.Dataset.from_tensors((ph1, ph2))
- data = data.apply(batching.unbatch())
- iterator = data.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- # Mismatch in the 0th dimension.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: np.arange(8).astype(np.int32)
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- # No 0th dimension (i.e. scalar value) for one component.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: 7
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- @parameterized.named_parameters(
- ("Default", None, None),
- ("SequentialCalls", 1, None),
- ("ParallelCalls", 2, None),
- ("ParallelBatches", None, 10),
- )
- def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
- """Test a dataset that maps a TF function across its input elements."""
- # The pipeline is TensorSliceDataset ->
- # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- count = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_calls=num_parallel_calls,
- num_parallel_batches=num_parallel_batches))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([[None] + list(c.shape[1:]) for c in components],
- [t.shape.as_list() for t in get_next])
-
- with self.cached_session() as sess:
- # Batch of a finite input, where the batch_size divides the
- # total number of elements.
- sess.run(init_op, feed_dict={count: 28, batch_size: 14})
- num_batches = (28 * 7) // 14
- for i in range(num_batches):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i * 14 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of a finite input, where the batch_size does not
- # divide the total number of elements.
- sess.run(init_op, feed_dict={count: 14, batch_size: 8})
-
- # We expect (num_batches - 1) full-sized batches.
- num_batches = int(math.ceil((14 * 7) / 8))
- for i in range(num_batches - 1):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(8):
- self.assertAllEqual(component[(i * 8 + j) % 7]**2,
- result_component[j])
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of an empty input should fail straight away.
- sess.run(init_op, feed_dict={count: 0, batch_size: 8})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Empty batch should be an initialization time error.
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, batch_size: 0})
-
- @parameterized.named_parameters(
- ("Even", False),
- ("Uneven", True),
- )
- def testMapAndBatchPartialBatch(self, drop_remainder):
- iterator = (
- dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]),
- batch_size=4,
- drop_remainder=drop_remainder)).make_one_shot_iterator())
- if drop_remainder:
- self.assertEqual([4, 1], iterator.output_shapes.as_list())
- else:
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
- if not drop_remainder:
- self.assertAllEqual([[64], [81]], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchYieldsPartialBatch(self):
- iterator = (dataset_ops.Dataset.range(10)
- .apply(batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]), 4))
- .make_one_shot_iterator())
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
- self.assertAllEqual([[64], [81]], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchParallelGetNext(self):
- iterator = (dataset_ops.Dataset.range(50000)
- .apply(batching.map_and_batch(lambda x: x, batch_size=100))
- .make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(5):
- got = sess.run(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchParallelGetNextDropRemainder(self):
- iterator = (
- dataset_ops.Dataset.range(49999).apply(
- batching.map_and_batch(
- lambda x: x, batch_size=100, drop_remainder=True))
- .make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(4):
- got = sess.run(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(2):
- actual = sess.run(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testMapAndBatchFails(self):
- """Test a dataset that maps a TF function across its input elements."""
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.check_numerics(
- constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
- init_op = iterator.initializer
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
- sess.run(init_op, feed_dict={batch_size: 14})
-
- def testMapAndBatchShapeMismatch(self):
- """Test a dataset that maps a TF function across its input elements."""
-
- def generator():
- yield [1]
- yield [2]
- yield [3]
- yield [[4, 5, 6]]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator, output_types=dtypes.int32)
- batch_size = 4
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "number of elements does not match"):
- sess.run(get_next)
-
- def testMapAndBatchImplicitDispose(self):
- # Tests whether a map and batch dataset will be cleaned up correctly when
- # the pipeline does not run it until exhaustion.
- # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
- # MapAndBatchDataset(f=square_3, batch_size=100).
- components = (np.arange(1000),
- np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
- np.array(37.0) * np.arange(1000))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
- 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
- dataset = dataset.prefetch(5)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for _ in range(3):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", 0),
- ("2", 5),
- ("3", 10),
- ("4", 90),
- ("5", 95),
- ("6", 99),
- )
- def testMapAndBatchOutOfRangeError(self, threshold):
-
- def raising_py_fn(i):
- if i >= threshold:
- raise StopIteration()
- else:
- return i
-
- iterator = (
- dataset_ops.Dataset.range(100).apply(
- batching.map_and_batch(
- lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
- batch_size=10)).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(threshold // 10):
- self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
- if threshold % 10 != 0:
- self.assertAllEqual(
- [threshold // 10 * 10 + j for j in range(threshold % 10)],
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", False, dtypes.bool),
- ("2", -42, dtypes.int8),
- ("3", -42, dtypes.int16),
- ("4", -42, dtypes.int32),
- ("5", -42, dtypes.int64),
- ("6", 42, dtypes.uint8),
- ("7", 42, dtypes.uint16),
- ("8", 42.0, dtypes.float16),
- ("9", 42.0, dtypes.float32),
- ("10", 42.0, dtypes.float64),
- ("11", b"hello", dtypes.string),
- )
- def testMapAndBatchTypes(self, element, dtype):
- def gen():
- yield element
-
- dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
- batching.map_and_batch(lambda x: x, batch_size=10))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- for _ in range(10):
- self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-
-
-class UnbatchDatasetBenchmark(test.Benchmark):
-
- def benchmarkNativeUnbatch(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.apply(batching.unbatch())
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (native) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_native_batch_size_%d" %
- batch_size)
-
- # Include a benchmark of the previous `unbatch()` implementation that uses
- # a composition of more primitive ops. Eventually we'd hope to generate code
- # that is as good in both cases.
- def benchmarkOldUnbatchImplementation(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (unfused) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
- batch_size)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
new file mode 100644
index 0000000000..3903ec49b9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
@@ -0,0 +1,322 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.bucket_by_sequence_length()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+def _element_length_fn(x, y=None):
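+  """Returns the size of the first dimension of `x`, ignoring `y`."""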
+ del y
+ return array_ops.shape(x)[0]
+
+
+def _to_sparse_tensor(record):
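+  """Constructs a `SparseTensor` from a dict of its components."""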
+ return sparse_tensor.SparseTensor(**record)
+
+
+def _format_record(array, sparse):
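+  """Returns `array` as-is, or as a dict of sparse tensor components."""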
+ if sparse:
+ return {
+ "values": array,
+ "indices": [[i] for i in range(len(array))],
+ "dense_shape": (len(array),)
+ }
+ return array
+
+
+def _get_record_type(sparse):
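+  """Returns the record output type(s) for dense or sparse records."""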
+ if sparse:
+ return {
+ "values": dtypes.int64,
+ "indices": dtypes.int64,
+ "dense_shape": dtypes.int64
+ }
+ return dtypes.int32
+
+
+def _get_record_shape(sparse):
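+  """Returns the record output shape(s) for dense or sparse records."""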
+ if sparse:
+ return {
+ "values": tensor_shape.TensorShape([None,]),
+ "indices": tensor_shape.TensorShape([None, 1]),
+ "dense_shape": tensor_shape.TensorShape([1,])
+ }
+ return tensor_shape.TensorShape([None])
+
+
+class BucketBySequenceLengthTest(test_base.DatasetTestBase):
+
+ def testBucket(self):
+
+ boundaries = [10, 20, 30]
+ batch_sizes = [10, 8, 4, 2]
+ lengths = [8, 13, 25, 35]
+
+ def build_dataset(sparse):
+ def _generator():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes, lengths):
+ record_len = length - 1
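+        # The first element produced for each bucket is one shorter than the
+        # rest, so the values in each batch sum to batch_size * length - 1
+        # (verified below).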
+ for _ in range(batch_size):
+ elements.append([1] * record_len)
+ record_len = length
+ random.shuffle(elements)
+ for el in elements:
+ yield (_format_record(el, sparse),)
+ dataset = dataset_ops.Dataset.from_generator(
+ _generator,
+ (_get_record_type(sparse),),
+ (_get_record_shape(sparse),))
+ if sparse:
+ dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
+ return dataset
+
+ def _test_bucket_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(
+ grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ batch_sizes,
+ no_padding=no_padding))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(4):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ shape = batch.dense_shape if no_padding else batch.shape
+ batch_size = shape[0]
+ length = shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ sum_check = batch.values.sum() if no_padding else batch.sum()
+ self.assertEqual(sum_check, batch_size * length - 1)
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+ for no_padding in (True, False):
+ _test_bucket_by_padding(no_padding)
+
+ def testPadToBoundary(self):
+
+ boundaries = [10, 20, 30]
+ batch_sizes = [10, 8, 4, 2]
+ lengths = [8, 13, 25]
+
+ def element_gen():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes[:-1], lengths):
+ for _ in range(batch_size):
+ elements.append([1] * length)
+ random.shuffle(elements)
+ for el in elements:
+ yield (el,)
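+      # Also yield elements longer than the last boundary; with
+      # pad_to_bucket_boundary=True they cannot be padded to any bucket
+      # boundary, so fetching them raises an error (checked below).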
+ for _ in range(batch_sizes[-1]):
+ el = [1] * (boundaries[-1] + 5)
+ yield (el,)
+
+ element_len = lambda el: array_ops.shape(el)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(3):
+ batches.append(sess.run(batch))
+ with self.assertRaisesOpError("bucket_boundaries"):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ batch_size = batch.shape[0]
+ length = batch.shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ batch_sizes = batch_sizes[:-1]
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
+ sorted(lengths_val))
+
+ def testPadToBoundaryNoExtraneousPadding(self):
+
+ boundaries = [3, 7, 11]
+ batch_sizes = [2, 2, 2, 2]
+ lengths = range(1, 11)
+
+ def element_gen():
+ for length in lengths:
+ yield ([1] * length,)
+
+ element_len = lambda element: array_ops.shape(element)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(5):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+
+ self.assertAllEqual(batches[0], [[1, 0],
+ [1, 1]])
+ self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1]])
+ self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
+
+ def testTupleElements(self):
+
+ def build_dataset(sparse):
+ def _generator():
+ text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+ label = [1, 2, 1, 2]
+ for x, y in zip(text, label):
+ yield (_format_record(x, sparse), y)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=_generator,
+ output_types=(_get_record_type(sparse), dtypes.int32),
+ output_shapes=(_get_record_shape(sparse),
+ tensor_shape.TensorShape([])))
+ if sparse:
+ dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
+ return dataset
+
+ def _test_tuple_elements_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=_element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8],
+ no_padding=no_padding))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
+ for no_padding in (True, False):
+ _test_tuple_elements_by_padding(no_padding)
+
+ def testBucketSparse(self):
+ """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+    The test runs on the following dataset:
+ [
+ [0],
+ [0, 1],
+ [0, 1, 2]
+ ...
+ [0, ..., max_len - 1]
+ ]
+ Sequences are bucketed by length and batched with
+ `batch_size` < `bucket_size`.
+ """
+
+ min_len = 0
+ max_len = 100
+ batch_size = 7
+ bucket_size = 10
+
+ def _build_dataset():
+ input_data = [range(i+1) for i in range(min_len, max_len)]
+ def generator_fn():
+ for record in input_data:
+ yield _format_record(record, sparse=True)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=generator_fn,
+ output_types=_get_record_type(sparse=True))
+ dataset = dataset.map(_to_sparse_tensor)
+ return dataset
+
+ def _compute_expected_batches():
+ """Computes expected batch outputs and stores in a set."""
+ all_expected_sparse_tensors = set()
+ for bucket_start_len in range(min_len, max_len, bucket_size):
+ for batch_offset in range(0, bucket_size, batch_size):
+ batch_start_len = bucket_start_len + batch_offset
+ batch_end_len = min(batch_start_len + batch_size,
+ bucket_start_len + bucket_size)
+ expected_indices = []
+ expected_values = []
+ for length in range(batch_start_len, batch_end_len):
+ for val in range(length + 1):
+ expected_indices.append((length - batch_start_len, val))
+ expected_values.append(val)
+ expected_sprs_tensor = (tuple(expected_indices),
+ tuple(expected_values))
+ all_expected_sparse_tensors.add(expected_sprs_tensor)
+ return all_expected_sparse_tensors
+
+ def _compute_batches(dataset):
+ """Computes actual batch outputs of dataset and stores in a set."""
+ batch = dataset.make_one_shot_iterator().get_next()
+ all_sparse_tensors = set()
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ output = sess.run(batch)
+ sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+ tuple(output.values))
+ all_sparse_tensors.add(sprs_tensor)
+ return all_sparse_tensors
+
+ dataset = _build_dataset()
+ boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ [batch_size] * (len(boundaries) + 1),
+ no_padding=True))
+ batches = _compute_batches(dataset)
+ expected_batches = _compute_expected_batches()
+ self.assertEqual(batches, expected_batches)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
deleted file mode 100644
index 153a03989b..0000000000
--- a/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
+++ /dev/null
@@ -1,824 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import random
-
-import numpy as np
-
-from tensorflow.python.data.experimental.ops import grouping
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-
-
-class GroupByReducerTest(test_base.DatasetTestBase):
-
- def checkResults(self, dataset, shapes, values):
- self.assertEqual(shapes, dataset.output_shapes)
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- for expected in values:
- got = sess.run(get_next)
- self.assertEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSum(self):
- reducer = grouping.Reducer(
- init_func=lambda _: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).apply(
- grouping.group_by_reducer(lambda x: x % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
-
- def testAverage(self):
-
- def reduce_fn(x, y):
- return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
- x[1] + 1), x[1] + 1
-
- reducer = grouping.Reducer(
- init_func=lambda _: (0.0, 0.0),
- reduce_func=reduce_fn,
- finalize_func=lambda x, _: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).apply(
- grouping.group_by_reducer(
- lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
-
- def testConcat(self):
- components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
- reducer = grouping.Reducer(
- init_func=lambda x: "",
- reduce_func=lambda x, y: x + y[0],
- finalize_func=lambda x: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensor_slices(components),
- dataset_ops.Dataset.range(2 * i))).apply(
- grouping.group_by_reducer(lambda x, y: y % 2, reducer))
- self.checkResults(
- dataset,
- shapes=tensor_shape.scalar(),
- values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
-
- def testSparseSum(self):
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1], dtype=np.int64)),
- dense_shape=np.array([1, 1]))
-
- reducer = grouping.Reducer(
- init_func=lambda _: _sparse(np.int64(0)),
- reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
- finalize_func=lambda x: x.values[0])
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
- grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
-
- def testChangingStateShape(self):
-
- def reduce_fn(x, _):
- # Statically known rank, but dynamic length.
- larger_dim = array_ops.concat([x[0], x[0]], 0)
- # Statically unknown rank.
- larger_rank = array_ops.expand_dims(x[1], 0)
- return larger_dim, larger_rank
-
- reducer = grouping.Reducer(
- init_func=lambda x: ([0], 1),
- reduce_func=reduce_fn,
- finalize_func=lambda x, y: (x, y))
-
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
- grouping.group_by_reducer(lambda x: x, reducer))
- self.assertEqual([None], dataset.output_shapes[0].as_list())
- self.assertIs(None, dataset.output_shapes[1].ndims)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- x, y = sess.run(get_next)
- self.assertAllEqual([0] * (2**i), x)
- self.assertAllEqual(np.array(1, ndmin=i), y)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testTypeMismatch(self):
- reducer = grouping.Reducer(
- init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
- reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- TypeError,
- "The element types for the new state must match the initial state."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: np.int64(0), reducer))
-
- # TODO(b/78665031): Remove once non-scalar keys are supported.
- def testInvalidKeyShape(self):
- reducer = grouping.Reducer(
- init_func=lambda x: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- ValueError, "`key_func` must return a single tf.int64 tensor."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
-
- # TODO(b/78665031): Remove once non-int64 keys are supported.
- def testInvalidKeyType(self):
- reducer = grouping.Reducer(
- init_func=lambda x: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- ValueError, "`key_func` must return a single tf.int64 tensor."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: "wrong", reducer))
-
- def testTuple(self):
- def init_fn(_):
- return np.array([], dtype=np.int64), np.int64(0)
-
- def reduce_fn(state, value):
- s1, s2 = state
- v1, v2 = value
- return array_ops.concat([s1, [v1]], 0), s2 + v2
-
- def finalize_fn(s1, s2):
- return s1, s2
-
- reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
- dataset = dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
- grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- x, y = sess.run(get_next)
- self.assertAllEqual(x, np.asarray([x for x in range(10)]))
- self.assertEqual(y, 45)
-
-
-class GroupByWindowTest(test_base.DatasetTestBase):
-
- def testSimple(self):
- components = np.random.randint(100, size=(200,)).astype(np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
- .apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- counts = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- result = sess.run(get_next)
-          self.assertTrue(
-              all(x % 2 == 0 for x in result) or
-              all(x % 2 == 1 for x in result))
- counts.append(result.shape[0])
-
- self.assertEqual(len(components), sum(counts))
- num_full_batches = len([c for c in counts if c == 4])
- self.assertGreaterEqual(num_full_batches, 24)
- self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
-
- def testImmediateOutput(self):
- components = np.array(
- [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
- grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- # The input is infinite, so this test demonstrates that:
- # 1. We produce output without having to consume the entire input,
- # 2. Different buckets can produce output at different rates, and
- # 3. For deterministic input, the output is deterministic.
- for _ in range(3):
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
- self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
-
- def testSmallGroups(self):
- components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
- # The small outputs at the end are deterministically produced in key
- # order.
- self.assertAllEqual([0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1], sess.run(get_next))
-
- def testEmpty(self):
- iterator = (
- dataset_ops.Dataset.range(4).apply(
- grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Window size must be greater than zero, but got 0."):
- print(sess.run(get_next))
-
- def testReduceFuncError(self):
- components = np.random.randint(100, size=(200,)).astype(np.int64)
-
- def reduce_func(_, xs):
- # Introduce an incorrect padded shape that cannot (currently) be
- # detected at graph construction time.
- return xs.padded_batch(
- 4,
- padded_shapes=(tensor_shape.TensorShape([]),
- constant_op.constant([5], dtype=dtypes.int64) * -1))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
- grouping.group_by_window(lambda x, _: x % 2, reduce_func,
- 32)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def testConsumeWindowDatasetMoreThanOnce(self):
- components = np.random.randint(50, size=(200,)).astype(np.int64)
-
- def reduce_func(key, window):
- # Apply two different kinds of padding to the input: tight
- # padding, and quantized (to a multiple of 10) padding.
- return dataset_ops.Dataset.zip((
- window.padded_batch(
- 4, padded_shapes=tensor_shape.TensorShape([None])),
- window.padded_batch(
- 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
- ))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
- .apply(grouping.group_by_window(
- lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
- reduce_func, 4))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- counts = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- tight_result, multiple_of_10_result = sess.run(get_next)
- self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
- self.assertAllEqual(tight_result,
- multiple_of_10_result[:, :tight_result.shape[1]])
- counts.append(tight_result.shape[0])
- self.assertEqual(len(components), sum(counts))
-
-
-# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
-# Currently, they use a constant batch size, though they should be updated to
-# use a different batch size per key.
-class BucketTest(test_base.DatasetTestBase):
-
- def _dynamicPad(self, bucket, window, window_size):
- # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
- # generic form of padded_batch that pads every component
- # dynamically and does not rely on static shape information about
- # the arguments.
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(bucket),
- window.padded_batch(
- 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
- [None]), tensor_shape.TensorShape([3])))))
-
- def testSingleBucket(self):
-
- def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda x, y, z: 0,
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- which_bucket, bucketed_values = sess.run(get_next)
-
- self.assertEqual(0, which_bucket)
-
- expected_scalar_int = np.arange(32, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
- for i in range(32):
- expected_unk_int64[i, :i] = i
- expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values[2])
-
- def testEvenOddBuckets(self):
-
- def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- # Get two minibatches (one containing even values, one containing odds)
- which_bucket_even, bucketed_values_even = sess.run(get_next)
- which_bucket_odd, bucketed_values_odd = sess.run(get_next)
-
- # Count number of bucket_tensors.
- self.assertEqual(3, len(bucketed_values_even))
- self.assertEqual(3, len(bucketed_values_odd))
-
- # Ensure bucket 0 was used for all minibatch entries.
- self.assertAllEqual(0, which_bucket_even)
- self.assertAllEqual(1, which_bucket_odd)
-
-      # Test the first bucket outputted, the evens starting at 0
- expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
- for i in range(0, 32):
- expected_unk_int64[i, :2 * i] = 2 * i
- expected_vec3_str = np.vstack(
- 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
-
- # Test the second bucket outputted, the odds starting at 1
- expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
- for i in range(0, 32):
- expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
- expected_vec3_str = np.vstack(
- 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
-
- def testEvenOddBucketsFilterOutAllOdd(self):
-
- def _map_fn(v):
- return {
- "x": v,
- "y": array_ops.fill([v], v),
- "z": array_ops.fill([3], string_ops.as_string(v))
- }
-
- def _dynamic_pad_fn(bucket, window, _):
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(bucket),
- window.padded_batch(
- 32, {
- "x": tensor_shape.TensorShape([]),
- "y": tensor_shape.TensorShape([None]),
- "z": tensor_shape.TensorShape([3])
- })))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
- .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
- lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- # Get two minibatches ([0, 2, ...] and [64, 66, ...])
- which_bucket0, bucketed_values_even0 = sess.run(get_next)
- which_bucket1, bucketed_values_even1 = sess.run(get_next)
-
- # Ensure that bucket 1 was completely filtered out
- self.assertAllEqual(0, which_bucket0)
- self.assertAllEqual(0, which_bucket1)
- self.assertAllEqual(
- np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
- self.assertAllEqual(
- np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
-
- def testDynamicWindowSize(self):
- components = np.arange(100).astype(np.int64)
-
- # Key fn: even/odd
- # Reduce fn: batches of 5
- # Window size fn: even=5, odd=10
-
- def window_size_func(key):
- window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
- return window_sizes[key]
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
- None, window_size_func))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- batches = 0
- while True:
- result = sess.run(get_next)
- is_even = all(x % 2 == 0 for x in result)
- is_odd = all(x % 2 == 1 for x in result)
- self.assertTrue(is_even or is_odd)
- expected_batch_size = 5 if is_even else 10
- self.assertEqual(expected_batch_size, result.shape[0])
- batches += 1
-
- self.assertEqual(batches, 15)
-
-
-def _element_length_fn(x, y=None):
- del y
- return array_ops.shape(x)[0]
-
-
-def _to_sparse_tensor(record):
- return sparse_tensor.SparseTensor(**record)
-
-
-def _format_record(array, sparse):
- if sparse:
- return {
- "values": array,
- "indices": [[i] for i in range(len(array))],
- "dense_shape": (len(array),)
- }
- return array
-
-
-def _get_record_type(sparse):
- if sparse:
- return {
- "values": dtypes.int64,
- "indices": dtypes.int64,
- "dense_shape": dtypes.int64
- }
- return dtypes.int32
-
-
-def _get_record_shape(sparse):
- if sparse:
- return {
- "values": tensor_shape.TensorShape([None,]),
- "indices": tensor_shape.TensorShape([None, 1]),
- "dense_shape": tensor_shape.TensorShape([1,])
- }
- return tensor_shape.TensorShape([None])
-
-
-class BucketBySequenceLength(test_base.DatasetTestBase):
-
- def testBucket(self):
-
- boundaries = [10, 20, 30]
- batch_sizes = [10, 8, 4, 2]
- lengths = [8, 13, 25, 35]
-
- def build_dataset(sparse):
- def _generator():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes, lengths):
- record_len = length - 1
- for _ in range(batch_size):
- elements.append([1] * record_len)
- record_len = length
- random.shuffle(elements)
- for el in elements:
- yield (_format_record(el, sparse),)
- dataset = dataset_ops.Dataset.from_generator(
- _generator,
- (_get_record_type(sparse),),
- (_get_record_shape(sparse),))
- if sparse:
- dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
- return dataset
-
- def _test_bucket_by_padding(no_padding):
- dataset = build_dataset(sparse=no_padding)
- dataset = dataset.apply(
- grouping.bucket_by_sequence_length(
- _element_length_fn,
- boundaries,
- batch_sizes,
- no_padding=no_padding))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(4):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- shape = batch.dense_shape if no_padding else batch.shape
- batch_size = shape[0]
- length = shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- sum_check = batch.values.sum() if no_padding else batch.sum()
- self.assertEqual(sum_check, batch_size * length - 1)
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(lengths), sorted(lengths_val))
-
- for no_padding in (True, False):
- _test_bucket_by_padding(no_padding)
-
- def testPadToBoundary(self):
-
- boundaries = [10, 20, 30]
- batch_sizes = [10, 8, 4, 2]
- lengths = [8, 13, 25]
-
- def element_gen():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes[:-1], lengths):
- for _ in range(batch_size):
- elements.append([1] * length)
- random.shuffle(elements)
- for el in elements:
- yield (el,)
- for _ in range(batch_sizes[-1]):
- el = [1] * (boundaries[-1] + 5)
- yield (el,)
-
- element_len = lambda el: array_ops.shape(el)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes,
- pad_to_bucket_boundary=True))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(3):
- batches.append(sess.run(batch))
- with self.assertRaisesOpError("bucket_boundaries"):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- batch_size = batch.shape[0]
- length = batch.shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- batch_sizes = batch_sizes[:-1]
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
- sorted(lengths_val))
-
- def testPadToBoundaryNoExtraneousPadding(self):
-
- boundaries = [3, 7, 11]
- batch_sizes = [2, 2, 2, 2]
- lengths = range(1, 11)
-
- def element_gen():
- for length in lengths:
- yield ([1] * length,)
-
- element_len = lambda element: array_ops.shape(element)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes,
- pad_to_bucket_boundary=True))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(5):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
-
- self.assertAllEqual(batches[0], [[1, 0],
- [1, 1]])
- self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 0, 0]])
- self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1]])
- self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
- self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
-
- def testTupleElements(self):
-
- def build_dataset(sparse):
- def _generator():
- text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
- label = [1, 2, 1, 2]
- for x, y in zip(text, label):
- yield (_format_record(x, sparse), y)
- dataset = dataset_ops.Dataset.from_generator(
- generator=_generator,
- output_types=(_get_record_type(sparse), dtypes.int32),
- output_shapes=(_get_record_shape(sparse),
- tensor_shape.TensorShape([])))
- if sparse:
- dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
- return dataset
-
- def _test_tuple_elements_by_padding(no_padding):
- dataset = build_dataset(sparse=no_padding)
- dataset = dataset.apply(grouping.bucket_by_sequence_length(
- element_length_func=_element_length_fn,
- bucket_batch_sizes=[2, 2, 2],
- bucket_boundaries=[0, 8],
- no_padding=no_padding))
- shapes = dataset.output_shapes
- self.assertEqual([None, None], shapes[0].as_list())
- self.assertEqual([None], shapes[1].as_list())
-
- for no_padding in (True, False):
- _test_tuple_elements_by_padding(no_padding)
-
- def testBucketSparse(self):
- """Tests bucketing of sparse tensors (case where `no_padding` == True).
-
- Test runs on following dataset:
- [
- [0],
- [0, 1],
- [0, 1, 2]
- ...
- [0, ..., max_len - 1]
- ]
- Sequences are bucketed by length and batched with
- `batch_size` < `bucket_size`.
- """
-
- min_len = 0
- max_len = 100
- batch_size = 7
- bucket_size = 10
-
- def _build_dataset():
- input_data = [range(i+1) for i in range(min_len, max_len)]
- def generator_fn():
- for record in input_data:
- yield _format_record(record, sparse=True)
- dataset = dataset_ops.Dataset.from_generator(
- generator=generator_fn,
- output_types=_get_record_type(sparse=True))
- dataset = dataset.map(_to_sparse_tensor)
- return dataset
-
- def _compute_expected_batches():
- """Computes expected batch outputs and stores in a set."""
- all_expected_sparse_tensors = set()
- for bucket_start_len in range(min_len, max_len, bucket_size):
- for batch_offset in range(0, bucket_size, batch_size):
- batch_start_len = bucket_start_len + batch_offset
- batch_end_len = min(batch_start_len + batch_size,
- bucket_start_len + bucket_size)
- expected_indices = []
- expected_values = []
- for length in range(batch_start_len, batch_end_len):
- for val in range(length + 1):
- expected_indices.append((length - batch_start_len, val))
- expected_values.append(val)
- expected_sprs_tensor = (tuple(expected_indices),
- tuple(expected_values))
- all_expected_sparse_tensors.add(expected_sprs_tensor)
- return all_expected_sparse_tensors
-
- def _compute_batches(dataset):
- """Computes actual batch outputs of dataset and stores in a set."""
- batch = dataset.make_one_shot_iterator().get_next()
- all_sparse_tensors = set()
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- output = sess.run(batch)
- sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
- tuple(output.values))
- all_sparse_tensors.add(sprs_tensor)
- return all_sparse_tensors
-
- dataset = _build_dataset()
- boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
- dataset = dataset.apply(grouping.bucket_by_sequence_length(
- _element_length_fn,
- boundaries,
- [batch_size] * (len(boundaries) + 1),
- no_padding=True))
- batches = _compute_batches(dataset)
- expected_batches = _compute_expected_batches()
- self.assertEqual(batches, expected_batches)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
index 7d7b842c17..adfacf1c9f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
@@ -12,440 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for prefetching_ops."""
+"""Tests for `tf.data.experimental.copy_to_device()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import threading
-
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
-
- def setUp(self):
- self._event = threading.Event()
-
- def _create_ds_and_iterator(self, device0, initializable=False):
-
- def gen():
- for i in range(1, 10):
- yield [float(i)]
- if i == 6:
- self._event.set()
-
- with ops.device(device0):
- ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
- if initializable:
- ds_iterator = ds.make_initializable_iterator()
- else:
- ds_iterator = ds.make_one_shot_iterator()
- return (ds, ds_iterator)
-
- def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
- ds_iterator_handle = ds_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, ds.output_types, ds.output_shapes)
- return remote_iterator.get_next()
-
- target = constant_op.constant(device0)
- with ops.device(device1):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_remote_fn,
- output_types=[dtypes.float32],
- target_device=target,
- string_arg=ds_iterator_handle,
- buffer_size=3,
- shared_name=buffer_name)
-
- with ops.device(device1):
- prefetch_op = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=buffer_resource_handle,
- output_types=[dtypes.float32])
- reset_op = prefetching_ops.function_buffering_resource_reset(
- function_buffer_resource=buffer_resource_handle)
- destroy_op = resource_variable_ops.destroy_resource_op(
- buffer_resource_handle, ignore_lookup_error=True)
-
- return (prefetch_op, reset_op, destroy_op)
-
- def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
- prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
- device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- sess.run(destroy_op)
-
- def testSameDeviceCPU(self):
- self._prefetch_fn_helper_one_shot("same_device_cpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/cpu:0")
-
- def testDifferentDeviceCPU(self):
- self._prefetch_fn_helper_one_shot("diff_device_cpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/cpu:1")
-
- def testDifferentDeviceCPUGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- self._prefetch_fn_helper_one_shot("cpu_gpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/gpu:0")
-
- def testReinitialization(self):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/cpu:1"
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
- prefetch_op, reset_op, destroy_op = self._create_ops(
- ds, ds_iterator, "reinit", device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- sess.run(ds_iterator.initializer)
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
-      # Let's reset the function buffering resource and reinitialize the
- # iterator. Should be able to go through this again.
- self._event.clear()
- sess.run(reset_op)
- sess.run(ds_iterator.initializer)
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- sess.run(destroy_op)
-
- def testReinitializationOutOfRange(self):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/cpu:1"
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
- prefetch_op, reset_op, destroy_op = self._create_ops(
- ds, ds_iterator, "reinit", device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- sess.run(ds_iterator.initializer)
- for i in range(1, 10):
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [float(i)])
-      # Try fetching twice after it's over to test end-of-sequence behavior.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- # Now reset everything and try it out again.
- self._event.clear()
- sess.run(reset_op)
- sess.run(ds_iterator.initializer)
- for i in range(1, 10):
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [float(i)])
-      # Try fetching twice after it's over to test end-of-sequence behavior.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- sess.run(destroy_op)
-
- def testStringsGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/gpu:0"
-
- ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
- ds_iterator = ds.make_one_shot_iterator()
- ds_iterator_handle = ds_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, ds.output_types, ds.output_shapes)
- return remote_iterator.get_next()
-
- target = constant_op.constant(device0)
- with ops.device(device1):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_remote_fn,
- output_types=[dtypes.string],
- target_device=target,
- string_arg=ds_iterator_handle,
- buffer_size=3,
- shared_name="strings")
-
- with ops.device(device1):
- prefetch_op = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=buffer_resource_handle,
- output_types=[dtypes.string])
- destroy_op = resource_variable_ops.destroy_resource_op(
- buffer_resource_handle, ignore_lookup_error=True)
-
- with self.cached_session() as sess:
- self.assertEqual([b"a"], sess.run(prefetch_op))
- self.assertEqual([b"b"], sess.run(prefetch_op))
- self.assertEqual([b"c"], sess.run(prefetch_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- sess.run(destroy_op)
-
-
-class PrefetchToDeviceTest(test_base.DatasetTestBase):
-
- def testPrefetchToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToSameDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device(
- "/job:localhost/replica:0/task:0/device:CPU:0"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchDictToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element["a"].dtype)
- self.assertEqual([], next_element["a"].shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual({"a": i}, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchSparseTensorsToDevice(self):
- def make_tensor(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
- host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
-
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- actual = sess.run(next_element)
- self.assertAllEqual([i], actual.values)
- self.assertAllEqual([[0, 0]], actual.indices)
- self.assertAllEqual([2, 2], actual.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/gpu:0"))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceWithReInit(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_initializable_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceGpuWithReInit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/gpu:0"))
-
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
-
class CopyToDeviceTest(test_base.DatasetTestBase):
  def testCopyToDevice(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/counter_test.py b/tensorflow/python/data/experimental/kernel_tests/counter_test.py
new file mode 100644
index 0000000000..4e114ac479
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/counter_test.py
@@ -0,0 +1,51 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.Counter`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import counter
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import test
+
+
+class CounterTest(test_base.DatasetTestBase):
+
+ def testCounter(self):
+ """Test dataset construction using `count`."""
+ iterator = (counter.Counter(start=3, step=4)
+ .make_one_shot_iterator())
+ get_next = iterator.get_next()
+ self.assertEqual([], get_next.shape.as_list())
+ self.assertEqual(dtypes.int64, get_next.dtype)
+
+ negative_iterator = (counter.Counter(start=0, step=-1)
+ .make_one_shot_iterator())
+ negative_get_next = negative_iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertEqual(3, sess.run(get_next))
+ self.assertEqual(3 + 4, sess.run(get_next))
+ self.assertEqual(3 + 2 * 4, sess.run(get_next))
+
+ self.assertEqual(0, sess.run(negative_get_next))
+ self.assertEqual(-1, sess.run(negative_get_next))
+ self.assertEqual(-2, sess.run(negative_get_next))
+
+
+if __name__ == "__main__":
+ test.main()
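
For reference, the moved test checks an unbounded arithmetic sequence. A minimal plain-Python sketch (not the TF implementation) of the values `Counter(start=3, step=4)` is expected to produce:

def counter_values(start=3, step=4):
  # Yields start, start + step, start + 2 * step, ... indefinitely,
  # mirroring the expectations asserted in testCounter above.
  value = start
  while True:
    yield value
    value += step

gen = counter_values()
assert [next(gen) for _ in range(3)] == [3, 7, 11]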
diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
index 4ee1779710..fb75be1fbc 100644
--- a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for CsvDatasetOp."""
+"""Tests for `tf.data.experimental.CsvDataset`."""
from __future__ import absolute_import
from __future__ import division
@@ -44,7 +44,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test_base.DatasetTestBase):
+class CsvDatasetTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
deleted file mode 100644
index 7f435b8239..0000000000
--- a/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
+++ /dev/null
@@ -1,692 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing serializable datasets."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-
-from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import lookup_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.util import nest
-
-
-def remove_variants(get_next_op):
- # TODO(b/72408568): Remove this once session.run can get
- # variant tensors.
- """Remove variants from a nest structure, so sess.run will execute."""
-
- def _remove_variant(x):
- if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
- return ()
- else:
- return x
-
- return nest.map_structure(_remove_variant, get_next_op)
-
-
-class DatasetSerializationTestBase(test.TestCase):
- """Base class for testing serializable datasets."""
-
- def tearDown(self):
- self._delete_ckpt()
-
- # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
- # (deprecated) saveable `SparseTensorSliceDataset`, once the API
-  # `from_sparse_tensor_slices()` and related tests are deleted.
- def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
- """Runs the core tests.
-
- Args:
- ds_fn1: 0-argument function that returns a Dataset.
- ds_fn2: 0-argument function that returns a Dataset different from
- ds_fn1. If None, verify_restore_in_modified_graph test is not run.
- num_outputs: Total number of outputs expected from this Dataset.
- sparse_tensors: Whether dataset is built from SparseTensor(s).
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_unused_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_fully_used_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_exhausted_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_init_before_restore(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_multiple_breaks(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_reset_restored_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_restore_in_empty_graph(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- if ds_fn2:
- self.verify_restore_in_modified_graph(
- ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
-
- def verify_unused_iterator(self,
- ds_fn,
- num_outputs,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that saving and restoring an unused iterator works.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn, [0],
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_fully_used_iterator(self, ds_fn, num_outputs,
- sparse_tensors=False):
- """Verifies that saving and restoring a fully used iterator works.
-
- Note that this only checks saving and restoring an iterator from which
- `num_outputs` items have been produced but does not check for an
- exhausted iterator, i.e., one from which an OutOfRange error has been
- returned.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if test fails.
- """
- self.verify_run_with_breaks(
- ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
-
- def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
- """Verifies that saving and restoring an exhausted iterator works.
-
- An exhausted iterator is one which has returned an OutOfRange error.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.gen_outputs(
- ds_fn, [],
- num_outputs,
- verify_exhausted=True,
- sparse_tensors=sparse_tensors)
- actual = self.gen_outputs(
- ds_fn, [],
- 0,
- ckpt_saved=True,
- verify_exhausted=True,
- sparse_tensors=sparse_tensors)
- self.assertEqual(len(actual), 0)
-
- def verify_init_before_restore(self,
- ds_fn,
- num_outputs,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that restoring into an already initialized iterator works.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs),
- num_outputs,
- init_before_restore=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_multiple_breaks(self,
- ds_fn,
- num_outputs,
- num_breaks=10,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to save/restore at multiple break points.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- num_breaks: The number of break points. These are uniformly spread in
- [0, num_outputs] both inclusive.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs, num_breaks),
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_reset_restored_iterator(self,
- ds_fn,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to re-initialize a restored iterator.
-
- This is useful when restoring a training checkpoint during validation.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Collect ground truth containing all outputs.
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Skip some items and save checkpoint.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Restore from checkpoint and then run init_op.
- with ops.Graph().as_default() as g:
- saver = self._import_meta_graph()
- init_op, get_next_op = self._get_iterator_ops_from_collection(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- self._initialize(init_op, sess)
- for _ in range(num_outputs):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- self.match(expected, actual)
-
- def verify_restore_in_modified_graph(self,
- ds_fn1,
- ds_fn2,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to restore an iterator in a modified graph.
-
- Builds an input pipeline using ds_fn1, runs it for `break_point` steps
- and saves a checkpoint. Then builds a new graph using ds_fn2, restores
- the checkpoint from ds_fn1 and verifies that the restore is successful.
-
- Args:
- ds_fn1: See `run_core_tests`.
- ds_fn2: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Skip `break_point` items and store the remaining produced from ds_fn1
- # in `expected`.
- self.gen_outputs(
- ds_fn1, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
- expected = self.gen_outputs(
- ds_fn1, [],
- num_outputs - break_point,
- ckpt_saved=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Generate `break_point` items from ds_fn1 and save checkpoint.
- self.gen_outputs(
- ds_fn1, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Build graph for ds_fn2 but load checkpoint for ds_fn1.
- with ops.Graph().as_default() as g:
- _, get_next_op, saver = self._build_graph(
- ds_fn2, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- for _ in range(num_outputs - break_point):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- self.match(expected, actual)
-
- def verify_restore_in_empty_graph(self,
- ds_fn,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to restore an iterator in an empty graph.
-
- Builds an input pipeline using ds_fn, runs it for `break_point` steps
- and saves a checkpoint. Then builds a new empty graph, restores
- the checkpoint from ds_fn and verifies that the restore is successful.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Skip `break_point` items and store the remaining produced from ds_fn
- # in `expected`.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs - break_point,
- ckpt_saved=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Generate `break_point` items from ds_fn and save checkpoint.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Build an empty graph but load checkpoint for ds_fn.
- with ops.Graph().as_default() as g:
- get_next_op, saver = self._build_empty_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- for _ in range(num_outputs - break_point):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- self.match(expected, actual)
-
- def verify_error_on_save(self,
- ds_fn,
- num_outputs,
- error,
- break_point=None,
- sparse_tensors=False):
- """Attempts to save a non-saveable iterator.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- error: Declared error when trying to save iterator.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if any test fails.
- """
-
- break_point = num_outputs // 2 if not break_point else break_point
- with ops.Graph().as_default() as g:
- init_op, get_next_op, saver = self._build_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._initialize(init_op, sess)
- for _ in range(break_point):
- sess.run(get_next_op)
- with self.assertRaises(error):
- self._save(sess, saver)
-
- def verify_run_with_breaks(self,
- ds_fn,
- break_points,
- num_outputs,
- init_before_restore=False,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that ds_fn() produces the same outputs with and without breaks.
-
- 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
- *without* stopping at break points.
- 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
- with stopping at break points.
-
- Deep matches outputs from 1 and 2.
-
- Args:
- ds_fn: See `gen_outputs`.
- break_points: See `gen_outputs`.
- num_outputs: See `gen_outputs`.
- init_before_restore: See `gen_outputs`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs,
- init_before_restore=init_before_restore,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- actual = self.gen_outputs(
- ds_fn,
- break_points,
- num_outputs,
- init_before_restore=init_before_restore,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- self.match(expected, actual)
-
- def gen_outputs(self,
- ds_fn,
- break_points,
- num_outputs,
- ckpt_saved=False,
- init_before_restore=False,
- sparse_tensors=False,
- verify_exhausted=True,
- save_checkpoint_at_end=True):
- """Generates elements from input dataset while stopping at break points.
-
- Produces `num_outputs` outputs and saves the state of the iterator in the
- Saver checkpoint.
-
- Args:
- ds_fn: 0-argument function that returns the dataset.
-        break_points: A list of integers. For each `break_point` in
-          `break_points`, we produce outputs until `break_point` items have
-          been produced and then checkpoint the state. The current graph and
-          session are destroyed, and a new graph and session are used to
-          produce outputs until the next break point or until `num_outputs`
-          elements have been produced. `break_point` must be <= `num_outputs`.
- num_outputs: The total number of outputs to produce from the iterator.
- ckpt_saved: Whether a checkpoint already exists. If False, we build the
- graph from ds_fn.
- init_before_restore: Whether init should be called before saver.restore.
- This is just so that we can verify that restoring an already initialized
- iterator works.
- sparse_tensors: Whether dataset is built from SparseTensor(s).
- verify_exhausted: Whether to verify that the iterator has been exhausted
- after producing `num_outputs` elements.
- save_checkpoint_at_end: Whether to save a checkpoint after producing all
-          outputs. If False, checkpoints are saved at each break point but not at the
- end. Note that checkpoints overwrite each other so there is always only
- a single checkpoint available. Defaults to True.
-
- Returns:
- A list of `num_outputs` items.
- """
- outputs = []
-
- def get_ops():
- if ckpt_saved:
- saver = self._import_meta_graph()
- init_op, get_next_op = self._get_iterator_ops_from_collection(
- ds_fn, sparse_tensors=sparse_tensors)
- else:
- init_op, get_next_op, saver = self._build_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- return init_op, get_next_op, saver
-
- for i in range(len(break_points) + 1):
- with ops.Graph().as_default() as g:
- init_op, get_next_op, saver = get_ops()
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- if ckpt_saved:
- if init_before_restore:
- self._initialize(init_op, sess)
- self._restore(saver, sess)
- else:
- self._initialize(init_op, sess)
- start = break_points[i - 1] if i > 0 else 0
- end = break_points[i] if i < len(break_points) else num_outputs
- num_iters = end - start
- for _ in range(num_iters):
- outputs.append(sess.run(get_next_op))
- if i == len(break_points) and verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- if save_checkpoint_at_end or i < len(break_points):
- self._save(sess, saver)
- ckpt_saved = True
-
- return outputs
-
- def match(self, expected, actual):
- """Matches nested structures.
-
- Recursively matches shape and values of `expected` and `actual`.
- Handles scalars, numpy arrays and other python sequence containers
- e.g. list, dict.
-
- Args:
- expected: Nested structure 1.
- actual: Nested structure 2.
-
- Raises:
- AssertionError if matching fails.
- """
- if isinstance(expected, np.ndarray):
- expected = expected.tolist()
- if isinstance(actual, np.ndarray):
- actual = actual.tolist()
- self.assertEqual(type(expected), type(actual))
-
- if nest.is_sequence(expected):
- self.assertEqual(len(expected), len(actual))
- if isinstance(expected, dict):
- for key1, key2 in zip(sorted(expected), sorted(actual)):
- self.assertEqual(key1, key2)
- self.match(expected[key1], actual[key2])
- else:
- for item1, item2 in zip(expected, actual):
- self.match(item1, item2)
- else:
- self.assertEqual(expected, actual)
-
- def does_not_match(self, expected, actual):
- with self.assertRaises(AssertionError):
- self.match(expected, actual)
-
- def gen_break_points(self, num_outputs, num_samples=10):
-    """Generates `num_samples` break points in [0, num_outputs]."""
- return np.linspace(0, num_outputs, num_samples, dtype=int)
-
- def _build_graph(self, ds_fn, sparse_tensors=False):
- iterator = ds_fn().make_initializable_iterator()
-
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- init_op = iterator.initializer
- if sparse_tensors:
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- else:
- get_next = iterator.get_next()
- self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
- sparse_tensors)
- saver = saver_lib.Saver(allow_empty=True)
- return init_op, get_next, saver
-
- def _build_empty_graph(self, ds_fn, sparse_tensors=False):
- iterator = iterator_ops.Iterator.from_structure(
- self._get_output_types(ds_fn),
- output_shapes=self._get_output_shapes(ds_fn),
- output_classes=self._get_output_classes(ds_fn))
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- if sparse_tensors:
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- else:
- get_next = iterator.get_next()
- saver = saver_lib.Saver(allow_empty=True)
- return get_next, saver
-
- def _add_iterator_ops_to_collection(self,
- init_op,
- get_next,
- ds_fn,
- sparse_tensors=False):
- ops.add_to_collection("iterator_ops", init_op)
-    # `get_next` may be a tuple, e.g. in TensorSliceDataset. Since collections
-    # do not support tuples, we flatten the tensors and restore the shape in
- # `_get_iterator_ops_from_collection`.
- if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
- ops.add_to_collection("iterator_ops", get_next.indices)
- ops.add_to_collection("iterator_ops", get_next.values)
- ops.add_to_collection("iterator_ops", get_next.dense_shape)
- return
-
- get_next_list = nest.flatten(get_next)
- for i, output_class in enumerate(
- nest.flatten(self._get_output_classes(ds_fn))):
- if output_class is sparse_tensor.SparseTensor:
- ops.add_to_collection("iterator_ops", get_next_list[i].indices)
- ops.add_to_collection("iterator_ops", get_next_list[i].values)
- ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
- else:
- ops.add_to_collection("iterator_ops", get_next_list[i])
-
- def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
- all_ops = ops.get_collection("iterator_ops")
- if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
- init_op, indices, values, dense_shape = all_ops
- return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
- get_next_list = []
- i = 1
- for output_class in nest.flatten(self._get_output_classes(ds_fn)):
- if output_class is sparse_tensor.SparseTensor:
- indices, values, dense_shape = all_ops[i:i + 3]
- i += 3
- get_next_list.append(
- sparse_tensor.SparseTensor(indices, values, dense_shape))
- else:
- get_next_list.append(all_ops[i])
- i += 1
- return all_ops[0], nest.pack_sequence_as(
- self._get_output_types(ds_fn), get_next_list)
-
- def _get_output_types(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_types
-
- def _get_output_shapes(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_shapes
-
- def _get_output_classes(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_classes
-
- def _ckpt_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _latest_ckpt(self):
- return checkpoint_management.latest_checkpoint(self.get_temp_dir())
-
- def _save(self, sess, saver):
- saver.save(sess, self._ckpt_path())
-
- def _restore(self, saver, sess):
- sess.run(lookup_ops.tables_initializer())
- saver.restore(sess, self._latest_ckpt())
-
- def _initialize(self, init_op, sess):
- sess.run(variables.global_variables_initializer())
- sess.run(lookup_ops.tables_initializer())
- sess.run(init_op)
-
- def _import_meta_graph(self):
- meta_file_path = self._ckpt_path() + ".meta"
- return saver_lib.import_meta_graph(meta_file_path)
-
- def _delete_ckpt(self):
- # Remove all checkpoint files.
- prefix = self._ckpt_path()
- pattern = prefix + "*"
- files = gfile.Glob(pattern)
- map(gfile.Remove, files)
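
For context on how the removed base class was driven, here is a minimal hypothetical subclass sketch; the class and helper names (`RangeSerializationTest`, `_build_ds`) are illustrative and not part of this change, but the `run_core_tests` call matches the signature documented above.

from tensorflow.python.data.experimental.kernel_tests import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops


class RangeSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_ds(self, stop):
    return dataset_ops.Dataset.range(stop)

  def testCore(self):
    # Saves/restores at several break points and compares against an
    # uninterrupted run of `num_outputs` elements.
    self.run_core_tests(lambda: self._build_ds(10),
                        lambda: self._build_ds(15),
                        num_outputs=10)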
diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
new file mode 100644
index 0000000000..73be6cbcca
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
@@ -0,0 +1,124 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.dense_to_sparse_batch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class DenseToSparseBatchTest(test_base.DatasetTestBase):
+
+ def testDenseToSparseBatchDataset(self):
+ components = np.random.randint(12, size=(100,)).astype(np.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x], x)).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual([[i, j]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)], results.indices)
+ self.assertAllEqual(
+ [c for c in components[start:start + 4] for _ in range(c)],
+ results.values)
+ self.assertAllEqual([min(4,
+ len(components) - start), 12],
+ results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithUnknownShape(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x, x], x)).apply(
+ batching.dense_to_sparse_batch(
+ 4, [5, None])).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual([[i, j, z]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)
+ for z in range(c)], results.indices)
+ self.assertAllEqual([
+ c
+ for c in components[start:start + 4] for _ in range(c)
+ for _ in range(c)
+ ], results.values)
+ self.assertAllEqual([
+ min(4,
+ len(components) - start), 5,
+ np.max(components[start:start + 4])
+ ], results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithInvalidShape(self):
+ input_tensor = array_ops.constant([[1]])
+ with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
+
+ def testDenseToSparseBatchDatasetShapeErrors(self):
+ input_tensor = array_ops.placeholder(dtypes.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # Initialize with an input tensor of incompatible rank.
+ sess.run(init_op, feed_dict={input_tensor: [[1]]})
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "incompatible with the row shape"):
+ sess.run(get_next)
+
+ # Initialize with an input tensor that is larger than `row_shape`.
+ sess.run(init_op, feed_dict={input_tensor: range(13)})
+ with self.assertRaisesRegexp(errors.DataLossError,
+ "larger than the row shape"):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
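
As a concrete illustration of the sparse batches the tests above expect, here is a plain-Python worked example that mirrors the comprehensions in testDenseToSparseBatchDataset for a single small input (the values are hand-checked, not produced by the op):

# Each element x becomes a length-x row; rows are batched (batch size 4,
# row shape [12]) into one SparseTensor per batch.
components = [2, 0, 3]
indices = [[i, j] for i, c in enumerate(components) for j in range(c)]
values = [c for c in components for _ in range(c)]
dense_shape = [min(4, len(components)), 12]
assert indices == [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2]]
assert values == [2, 2, 3, 3, 3]
assert dense_shape == [3, 12]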
diff --git a/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
index 22412c3965..e54235d9f8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Test RangeDataset."""
+"""Tests for `tf.data.experimental.enumerate_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.experimental.ops import enumerate_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@@ -28,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class RangeDatasetTest(test_base.DatasetTestBase):
+class EnumerateDatasetTest(test_base.DatasetTestBase):
def testEnumerateDataset(self):
components = (["a", "b"], [1, 2], [37.0, 38])
@@ -52,27 +51,6 @@ class RangeDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testCounter(self):
- """Test dataset construction using `count`."""
- iterator = (counter.Counter(start=3, step=4)
- .make_one_shot_iterator())
- get_next = iterator.get_next()
- self.assertEqual([], get_next.shape.as_list())
- self.assertEqual(dtypes.int64, get_next.dtype)
-
- negative_iterator = (counter.Counter(start=0, step=-1)
- .make_one_shot_iterator())
- negative_get_next = negative_iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertEqual(3, sess.run(get_next))
- self.assertEqual(3 + 4, sess.run(get_next))
- self.assertEqual(3 + 2 * 4, sess.run(get_next))
-
- self.assertEqual(0, sess.run(negative_get_next))
- self.assertEqual(-1, sess.run(negative_get_next))
- self.assertEqual(-2, sess.run(negative_get_next))
-
if __name__ == "__main__":
test.main()
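
To make the expected pairing concrete, a plain-Python analogue of `enumerate_dataset(start)` applied to slices of the components above (the start value of 100 is illustrative, chosen only for this sketch):

components = (["a", "b"], [1, 2], [37.0, 38])
start = 100
# Dataset.from_tensor_slices(components) yields ("a", 1, 37.0), ("b", 2, 38);
# enumerate_dataset pairs each element with a running int64 index.
expected = list(enumerate(zip(*components), start))
assert expected == [(100, ("a", 1, 37.0)), (101, ("b", 2, 38))]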
diff --git a/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py b/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py
new file mode 100644
index 0000000000..399fd284f4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py
@@ -0,0 +1,247 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the private `FunctionBufferingResource` used in prefetching."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+
+
+class FunctionBufferingResourceTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ self._event = threading.Event()
+
+ def _create_ds_and_iterator(self, device0, initializable=False):
+
+ def gen():
+ for i in range(1, 10):
+ yield [float(i)]
+ if i == 6:
+ self._event.set()
+
+ with ops.device(device0):
+ ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
+ if initializable:
+ ds_iterator = ds.make_initializable_iterator()
+ else:
+ ds_iterator = ds.make_one_shot_iterator()
+ return (ds, ds_iterator)
+
+ def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.float32],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name=buffer_name)
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.float32])
+ reset_op = prefetching_ops.function_buffering_resource_reset(
+ function_buffer_resource=buffer_resource_handle)
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ return (prefetch_op, reset_op, destroy_op)
+
+ def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
+ prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
+ device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ sess.run(destroy_op)
+
+ def testSameDeviceCPU(self):
+ self._prefetch_fn_helper_one_shot("same_device_cpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/cpu:0")
+
+ def testDifferentDeviceCPU(self):
+ self._prefetch_fn_helper_one_shot("diff_device_cpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/cpu:1")
+
+ def testDifferentDeviceCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ self._prefetch_fn_helper_one_shot("cpu_gpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/gpu:0")
+
+ def testReinitialization(self):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/cpu:1"
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+ prefetch_op, reset_op, destroy_op = self._create_ops(
+ ds, ds_iterator, "reinit", device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ sess.run(ds_iterator.initializer)
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+      # Let's reset the function buffering resource and reinitialize the
+ # iterator. Should be able to go through this again.
+ self._event.clear()
+ sess.run(reset_op)
+ sess.run(ds_iterator.initializer)
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ sess.run(destroy_op)
+
+ def testReinitializationOutOfRange(self):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/cpu:1"
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+ prefetch_op, reset_op, destroy_op = self._create_ops(
+ ds, ds_iterator, "reinit", device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ sess.run(ds_iterator.initializer)
+ for i in range(1, 10):
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [float(i)])
+      # Try fetching twice after it's over to test end-of-sequence behavior.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ # Now reset everything and try it out again.
+ self._event.clear()
+ sess.run(reset_op)
+ sess.run(ds_iterator.initializer)
+ for i in range(1, 10):
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [float(i)])
+      # Try fetching twice after it's over to test end-of-sequence behavior.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
+ def testStringsGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/gpu:0"
+
+ ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
+ ds_iterator = ds.make_one_shot_iterator()
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.string],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name="strings")
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.string])
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ with self.cached_session() as sess:
+ self.assertEqual([b"a"], sess.run(prefetch_op))
+ self.assertEqual([b"b"], sess.run(prefetch_op))
+ self.assertEqual([b"c"], sess.run(prefetch_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
+
+if __name__ == "__main__":
+ test.main()
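
The resource exercised above is the low-level buffering machinery used by the prefetching transformations; a minimal sketch of the public-facing pattern, assuming the `prefetch_to_device` usage shown in the prefetch tests earlier in this change:

from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops

# Build a host-side dataset and prefetch its elements to another device.
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
    prefetching_ops.prefetch_to_device("/gpu:0"))
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
# In a session: run `iterator.initializer`, then `next_element` repeatedly
# until `errors.OutOfRangeError` is raised.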
diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
new file mode 100644
index 0000000000..9030328593
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.group_by_reducer()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class GroupByReducerTest(test_base.DatasetTestBase):
+
+ def checkResults(self, dataset, shapes, values):
+ self.assertEqual(shapes, dataset.output_shapes)
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ for expected in values:
+ got = sess.run(get_next)
+ self.assertEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSum(self):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(lambda x: x % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testAverage(self):
+
+ def reduce_fn(x, y):
+ return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
+ x[1] + 1), x[1] + 1
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: (0.0, 0.0),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x, _: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(
+ lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
+
+ def testConcat(self):
+ components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
+ reducer = grouping.Reducer(
+ init_func=lambda x: "",
+ reduce_func=lambda x, y: x + y[0],
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensor_slices(components),
+ dataset_ops.Dataset.range(2 * i))).apply(
+ grouping.group_by_reducer(lambda x, y: y % 2, reducer))
+ self.checkResults(
+ dataset,
+ shapes=tensor_shape.scalar(),
+          values=[b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]])
+
+ def testSparseSum(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1], dtype=np.int64)),
+ dense_shape=np.array([1, 1]))
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: _sparse(np.int64(0)),
+ reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
+ finalize_func=lambda x: x.values[0])
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
+ grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testChangingStateShape(self):
+
+ def reduce_fn(x, _):
+ # Statically known rank, but dynamic length.
+ larger_dim = array_ops.concat([x[0], x[0]], 0)
+ # Statically unknown rank.
+ larger_rank = array_ops.expand_dims(x[1], 0)
+ return larger_dim, larger_rank
+
+ reducer = grouping.Reducer(
+ init_func=lambda x: ([0], 1),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x, y: (x, y))
+
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
+ grouping.group_by_reducer(lambda x: x, reducer))
+ self.assertEqual([None], dataset.output_shapes[0].as_list())
+ self.assertIs(None, dataset.output_shapes[1].ndims)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual([0] * (2**i), x)
+ self.assertAllEqual(np.array(1, ndmin=i), y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testTypeMismatch(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
+ reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The element types for the new state must match the initial state."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64(0), reducer))
+
+ # TODO(b/78665031): Remove once non-scalar keys are supported.
+ def testInvalidKeyShape(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
+
+ # TODO(b/78665031): Remove once non-int64 keys are supported.
+ def testInvalidKeyType(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: "wrong", reducer))
+
+ def testTuple(self):
+ def init_fn(_):
+ return np.array([], dtype=np.int64), np.int64(0)
+
+ def reduce_fn(state, value):
+ s1, s2 = state
+ v1, v2 = value
+ return array_ops.concat([s1, [v1]], 0), s2 + v2
+
+ def finalize_fn(s1, s2):
+ return s1, s2
+
+ reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
+ grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual(x, np.asarray([x for x in range(10)]))
+ self.assertEqual(y, 45)
+
+
+if __name__ == "__main__":
+ test.main()
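
A plain-Python sketch (not the TF implementation) of the Reducer contract exercised by testSum: elements are grouped by key, each group is folded with `reduce_func` starting from `init_func`, and `finalize_func` is applied to the final state.

def group_by_reducer(elements, key_fn, init_fn, reduce_fn, finalize_fn):
  # Fold each key's elements into a per-key state, then finalize.
  states = {}
  for x in elements:
    k = key_fn(x)
    if k not in states:
      states[k] = init_fn(k)
    states[k] = reduce_fn(states[k], x)
  return [finalize_fn(s) for _, s in sorted(states.items())]

# For i == 3, range(2 * i) keyed by parity sums to [(i - 1) * i, i * i].
assert group_by_reducer(range(6), lambda x: x % 2, lambda _: 0,
                        lambda x, y: x + y, lambda x: x) == [6, 9]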
diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
new file mode 100644
index 0000000000..557d56e8b9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
@@ -0,0 +1,367 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.group_by_window()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
+# Currently, they use a constant batch size, though they should be made to use a
+# different batch size per key.
+class GroupByWindowTest(test_base.DatasetTestBase):
+
+ def _dynamicPad(self, bucket, window, window_size):
+ # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
+ # generic form of padded_batch that pads every component
+ # dynamically and does not rely on static shape information about
+ # the arguments.
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket),
+ window.padded_batch(
+ 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
+ [None]), tensor_shape.TensorShape([3])))))
+
+ def testSingleBucket(self):
+
+ def _map_fn(v):
+ return (v, array_ops.fill([v], v),
+ array_ops.fill([3], string_ops.as_string(v)))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda x, y, z: 0,
+ lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ which_bucket, bucketed_values = sess.run(get_next)
+
+ self.assertEqual(0, which_bucket)
+
+ expected_scalar_int = np.arange(32, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
+ for i in range(32):
+ expected_unk_int64[i, :i] = i
+ expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values[2])
+
+ def testEvenOddBuckets(self):
+
+ def _map_fn(v):
+ return (v, array_ops.fill([v], v),
+ array_ops.fill([3], string_ops.as_string(v)))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
+ lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ # Get two minibatches (one containing even values, one containing odds)
+ which_bucket_even, bucketed_values_even = sess.run(get_next)
+ which_bucket_odd, bucketed_values_odd = sess.run(get_next)
+
+ # Count number of bucket_tensors.
+ self.assertEqual(3, len(bucketed_values_even))
+ self.assertEqual(3, len(bucketed_values_odd))
+
+ # Ensure bucket 0 was used for all minibatch entries.
+ self.assertAllEqual(0, which_bucket_even)
+ self.assertAllEqual(1, which_bucket_odd)
+
+      # Test the first bucket outputted, the evens starting at 0
+ expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
+ for i in range(0, 32):
+ expected_unk_int64[i, :2 * i] = 2 * i
+ expected_vec3_str = np.vstack(
+ 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
+
+ # Test the second bucket outputted, the odds starting at 1
+ expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
+ for i in range(0, 32):
+ expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
+ expected_vec3_str = np.vstack(
+ 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
+
+ def testEvenOddBucketsFilterOutAllOdd(self):
+
+ def _map_fn(v):
+ return {
+ "x": v,
+ "y": array_ops.fill([v], v),
+ "z": array_ops.fill([3], string_ops.as_string(v))
+ }
+
+ def _dynamic_pad_fn(bucket, window, _):
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket),
+ window.padded_batch(
+ 32, {
+ "x": tensor_shape.TensorShape([]),
+ "y": tensor_shape.TensorShape([None]),
+ "z": tensor_shape.TensorShape([3])
+ })))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
+ .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
+ lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ # Get two minibatches ([0, 2, ...] and [64, 66, ...])
+ which_bucket0, bucketed_values_even0 = sess.run(get_next)
+ which_bucket1, bucketed_values_even1 = sess.run(get_next)
+
+ # Ensure that bucket 1 was completely filtered out
+ self.assertAllEqual(0, which_bucket0)
+ self.assertAllEqual(0, which_bucket1)
+ self.assertAllEqual(
+ np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
+ self.assertAllEqual(
+ np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
+
+ def testDynamicWindowSize(self):
+ components = np.arange(100).astype(np.int64)
+
+ # Key fn: even/odd
+    # Reduce fn: batch the whole window (so batches contain 5 or 10 elements)
+ # Window size fn: even=5, odd=10
+
+ def window_size_func(key):
+ window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
+ return window_sizes[key]
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
+ None, window_size_func))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ batches = 0
+ while True:
+ result = sess.run(get_next)
+ is_even = all(x % 2 == 0 for x in result)
+ is_odd = all(x % 2 == 1 for x in result)
+ self.assertTrue(is_even or is_odd)
+ expected_batch_size = 5 if is_even else 10
+ self.assertEqual(expected_batch_size, result.shape[0])
+ batches += 1
+
+ self.assertEqual(batches, 15)
+
+ def testSimple(self):
+ components = np.random.randint(100, size=(200,)).astype(np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
+ .apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ counts = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ result = sess.run(get_next)
+          self.assertTrue(
+              all(x % 2 == 0 for x in result) or
+              all(x % 2 == 1 for x in result))
+ counts.append(result.shape[0])
+
+ self.assertEqual(len(components), sum(counts))
+ num_full_batches = len([c for c in counts if c == 4])
+ self.assertGreaterEqual(num_full_batches, 24)
+ self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
+
+ def testImmediateOutput(self):
+ components = np.array(
+ [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
+ grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ # The input is infinite, so this test demonstrates that:
+ # 1. We produce output without having to consume the entire input,
+ # 2. Different buckets can produce output at different rates, and
+ # 3. For deterministic input, the output is deterministic.
+ for _ in range(3):
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
+ self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+
+ def testSmallGroups(self):
+ components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
+ # The small outputs at the end are deterministically produced in key
+ # order.
+ self.assertAllEqual([0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1], sess.run(get_next))
+
+ def testEmpty(self):
+ iterator = (
+ dataset_ops.Dataset.range(4).apply(
+ grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Window size must be greater than zero, but got 0."):
+ print(sess.run(get_next))
+
+ def testReduceFuncError(self):
+ components = np.random.randint(100, size=(200,)).astype(np.int64)
+
+ def reduce_func(_, xs):
+ # Introduce an incorrect padded shape that cannot (currently) be
+ # detected at graph construction time.
+ return xs.padded_batch(
+ 4,
+ padded_shapes=(tensor_shape.TensorShape([]),
+ constant_op.constant([5], dtype=dtypes.int64) * -1))
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
+ grouping.group_by_window(lambda x, _: x % 2, reduce_func,
+ 32)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def testConsumeWindowDatasetMoreThanOnce(self):
+ components = np.random.randint(50, size=(200,)).astype(np.int64)
+
+ def reduce_func(key, window):
+ # Apply two different kinds of padding to the input: tight
+ # padding, and quantized (to a multiple of 10) padding.
+ return dataset_ops.Dataset.zip((
+ window.padded_batch(
+ 4, padded_shapes=tensor_shape.TensorShape([None])),
+ window.padded_batch(
+ 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
+ ))
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
+ .apply(grouping.group_by_window(
+ lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
+ reduce_func, 4))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ counts = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ tight_result, multiple_of_10_result = sess.run(get_next)
+ self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
+ self.assertAllEqual(tight_result,
+ multiple_of_10_result[:, :tight_result.shape[1]])
+ counts.append(tight_result.shape[0])
+ self.assertEqual(len(components), sum(counts))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
new file mode 100644
index 0000000000..c0ec1486ab
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
@@ -0,0 +1,115 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.ignore_errors()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+_NUMPY_RANDOM_SEED = 42
+
+
+class IgnoreErrorsTest(test_base.DatasetTestBase):
+
+ def testMapIgnoreError(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+
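+    # check_numerics() raises on the NaN element; ignore_errors() silently
+    # drops that element, so only the finite values are expected below.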
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.check_numerics(x, "message")).apply(
+ error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for x in [1., 2., 3., 5.]:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testParallelMapIgnoreError(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.check_numerics(x, "message"),
+ num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for x in [1., 2., 3., 5.]:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testReadFileIgnoreError(self):
+
+ def write_string_to_file(value, filename):
+ with open(filename, "w") as f:
+ f.write(value)
+
+ filenames = [
+ os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
+ ]
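+    # Write each file's own path as its contents, so read_file() is expected
+    # to return the filename string itself.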
+ for filename in filenames:
+ write_string_to_file(filename, filename)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(filenames).map(
+ io_ops.read_file,
+ num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # All of the files are present.
+ sess.run(init_op)
+ for filename in filenames:
+ self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Delete one of the files.
+ os.remove(filenames[0])
+
+ # Attempting to read filenames[0] will fail, but ignore_errors()
+ # will catch the error.
+ sess.run(init_op)
+ for filename in filenames[1:]:
+ self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
new file mode 100644
index 0000000000..5ee94e14dc
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
@@ -0,0 +1,239 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.make_batched_features_dataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+
+
+class MakeBatchedFeaturesDatasetTest(
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
+
+ def testRead(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from file 0.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ 0,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from file 1.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames[1],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ 1,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, num_epochs=num_epochs)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
+ def testReadWithEquivalentDataset(self):
+ features = {
+ "file": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "record": parsing_ops.FixedLenFeature([], dtypes.int64),
+ }
+ dataset = (
+ core_readers.TFRecordDataset(self.test_filenames)
+ .map(lambda x: parsing_ops.parse_single_example(x, features))
+ .repeat(10).batch(2))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
+ range(self._num_files), 2, 10):
+ actual_batch = sess.run(next_element)
+ self.assertAllEqual(file_batch, actual_batch["file"])
+ self.assertAllEqual(record_batch, actual_batch["record"])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testReadWithFusedShuffleRepeatDataset(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ for batch_size in [1, 2]:
+      # Test that shuffling with the same seed produces the same result.
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs1 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ self.assertAllEqual(batch1[i], batch2[i])
+
+ # Test that shuffling with different seeds produces a different order.
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs1 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=15).make_one_shot_iterator().get_next()
+ all_equal = True
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
+ self.assertFalse(all_equal)
+
+ def testParallelReadersAndParsers(self):
+ num_epochs = 5
+ for batch_size in [1, 2]:
+ for reader_num_threads in [2, 4]:
+ for parser_num_threads in [2, 4]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default():
+ # Basic test: read from file 0.
+ outputs = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ drop_final_batch=True).make_one_shot_iterator().get_next()
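+          # Dropping the final (possibly short) batch makes the batch
+          # dimension statically known, so every dense output should report
+          # shape[0] == batch_size.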
+ for tensor in nest.flatten(outputs):
+ if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
+ self.assertEqual(tensor.shape[0], batch_size)
+
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=None,
+ batch_size=32)
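+    # With num_epochs=None the dataset repeats indefinitely, so there is never
+    # a short final batch and the batch dimension can be inferred statically.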
+ for shape, clazz in zip(nest.flatten(dataset.output_shapes),
+ nest.flatten(dataset.output_classes)):
+ if issubclass(clazz, ops.Tensor):
+ self.assertEqual(32, shape[0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
index a02f4bd14f..e4bf089184 100644
--- a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.make_csv_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -23,226 +23,16 @@ import zlib
import numpy as np
-from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class ReadBatchFeaturesTest(
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
-
- def testRead(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 10]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from file 0.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- 0,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from file 1.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames[1],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- 1,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from both files.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from both files.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, num_epochs=num_epochs)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
-
- def testReadWithEquivalentDataset(self):
- features = {
- "file": parsing_ops.FixedLenFeature([], dtypes.int64),
- "record": parsing_ops.FixedLenFeature([], dtypes.int64),
- }
- dataset = (
- core_readers.TFRecordDataset(self.test_filenames)
- .map(lambda x: parsing_ops.parse_single_example(x, features))
- .repeat(10).batch(2))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
- range(self._num_files), 2, 10):
- actual_batch = sess.run(next_element)
- self.assertAllEqual(file_batch, actual_batch["file"])
- self.assertAllEqual(record_batch, actual_batch["record"])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testReadWithFusedShuffleRepeatDataset(self):
- num_epochs = 5
- total_records = num_epochs * self._num_records
- for batch_size in [1, 2]:
- # Test that shuffling with same seed produces the same result.
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs1 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- outputs2 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
- for i in range(len(batch1)):
- self.assertAllEqual(batch1[i], batch2[i])
-
- # Test that shuffling with different seeds produces a different order.
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs1 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- outputs2 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=15).make_one_shot_iterator().get_next()
- all_equal = True
- for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
- for i in range(len(batch1)):
- all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
- self.assertFalse(all_equal)
-
- def testParallelReadersAndParsers(self):
- num_epochs = 5
- for batch_size in [1, 2]:
- for reader_num_threads in [2, 4]:
- for parser_num_threads in [2, 4]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
- ).get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- label_key_provided=True,
- interleave_cycle_length=reader_num_threads)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
- ).get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- interleave_cycle_length=reader_num_threads)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
-
- def testDropFinalBatch(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 10]:
- with ops.Graph().as_default():
- # Basic test: read from file 0.
- outputs = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size,
- drop_final_batch=True).make_one_shot_iterator().get_next()
- for tensor in nest.flatten(outputs):
- if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
- self.assertEqual(tensor.shape[0], batch_size)
-
- def testIndefiniteRepeatShapeInference(self):
- dataset = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=None,
- batch_size=32)
- for shape, clazz in zip(nest.flatten(dataset.output_shapes),
- nest.flatten(dataset.output_classes)):
- if issubclass(clazz, ops.Tensor):
- self.assertEqual(32, shape[0])
-
-
class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
@@ -866,218 +656,5 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
self.assertEqual(32, shape[0])
-class MakeTFRecordDatasetTest(
- reader_dataset_ops_test_base.TFRecordDatasetTestBase):
-
- def _interleave(self, iterators, cycle_length):
- pending_iterators = iterators
- open_iterators = []
- num_open = 0
- for i in range(cycle_length):
- if pending_iterators:
- open_iterators.append(pending_iterators.pop(0))
- num_open += 1
-
- while num_open:
- for i in range(min(cycle_length, len(open_iterators))):
- if open_iterators[i] is None:
- continue
- try:
- yield next(open_iterators[i])
- except StopIteration:
- if pending_iterators:
- open_iterators[i] = pending_iterators.pop(0)
- else:
- open_iterators[i] = None
- num_open -= 1
-
- def _next_expected_batch(self,
- file_indices,
- batch_size,
- num_epochs,
- cycle_length,
- drop_final_batch,
- use_parser_fn):
-
- def _next_record(file_indices):
- for j in file_indices:
- for i in range(self._num_records):
- yield j, i
-
- def _next_record_interleaved(file_indices, cycle_length):
- return self._interleave([_next_record([i]) for i in file_indices],
- cycle_length)
-
- record_batch = []
- batch_index = 0
- for _ in range(num_epochs):
- if cycle_length == 1:
- next_records = _next_record(file_indices)
- else:
- next_records = _next_record_interleaved(file_indices, cycle_length)
- for f, r in next_records:
- record = self._record(f, r)
- if use_parser_fn:
- record = record[1:]
- record_batch.append(record)
- batch_index += 1
- if len(record_batch) == batch_size:
- yield record_batch
- record_batch = []
- batch_index = 0
- if record_batch and not drop_final_batch:
- yield record_batch
-
- def _verify_records(self,
- sess,
- outputs,
- batch_size,
- file_index,
- num_epochs,
- interleave_cycle_length,
- drop_final_batch,
- use_parser_fn):
- if file_index is not None:
- file_indices = [file_index]
- else:
- file_indices = range(self._num_files)
-
- for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length,
- drop_final_batch, use_parser_fn):
- actual_batch = sess.run(outputs)
- self.assertAllEqual(expected_batch, actual_batch)
-
- def _read_test(self, batch_size, num_epochs, file_index=None,
- num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
- if file_index is None:
- file_pattern = self.test_filenames
- else:
- file_pattern = self.test_filenames[file_index]
-
- if parser_fn:
- fn = lambda x: string_ops.substr(x, 1, 999)
- else:
- fn = None
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs = readers.make_tf_record_dataset(
- file_pattern=file_pattern,
- num_epochs=num_epochs,
- batch_size=batch_size,
- parser_fn=fn,
- num_parallel_reads=num_parallel_reads,
- drop_final_batch=drop_final_batch,
- shuffle=False).make_one_shot_iterator().get_next()
- self._verify_records(
- sess, outputs, batch_size, file_index, num_epochs=num_epochs,
- interleave_cycle_length=num_parallel_reads,
- drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(outputs)
-
- def testRead(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- # Basic test: read from file 0.
- self._read_test(batch_size, num_epochs, 0)
-
- # Basic test: read from file 1.
- self._read_test(batch_size, num_epochs, 1)
-
- # Basic test: read from both files.
- self._read_test(batch_size, num_epochs)
-
- # Basic test: read from both files, with parallel reads.
- self._read_test(batch_size, num_epochs, num_parallel_reads=8)
-
- def testDropFinalBatch(self):
- for batch_size in [1, 2, 10]:
- for num_epochs in [1, 3]:
- # Read from file 0.
- self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
-
- # Read from both files.
- self._read_test(batch_size, num_epochs, drop_final_batch=True)
-
- # Read from both files, with parallel reads.
- self._read_test(batch_size, num_epochs, num_parallel_reads=8,
- drop_final_batch=True)
-
- def testParserFn(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- for drop_final_batch in [False, True]:
- self._read_test(batch_size, num_epochs, parser_fn=True,
- drop_final_batch=drop_final_batch)
- self._read_test(batch_size, num_epochs, num_parallel_reads=8,
- parser_fn=True, drop_final_batch=drop_final_batch)
-
- def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
- seed=None):
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.make_tf_record_dataset(
- file_pattern=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size,
- num_parallel_reads=num_parallel_reads,
- shuffle=True,
- shuffle_seed=seed)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- sess.run(iterator.initializer)
- first_batches = []
- try:
- while True:
- first_batches.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
-
- sess.run(iterator.initializer)
- second_batches = []
- try:
- while True:
- second_batches.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
-
- self.assertEqual(len(first_batches), len(second_batches))
- if seed is not None:
- # if you set a seed, should get the same results
- for i in range(len(first_batches)):
- self.assertAllEqual(first_batches[i], second_batches[i])
-
- expected = []
- for f in range(self._num_files):
- for r in range(self._num_records):
- expected.extend([self._record(f, r)] * num_epochs)
-
- for batches in (first_batches, second_batches):
- actual = []
- for b in batches:
- actual.extend(b)
- self.assertAllEqual(sorted(expected), sorted(actual))
-
- def testShuffle(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- for num_parallel_reads in [1, 2]:
- # Test that all expected elements are produced
- self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
- # Test that elements are produced in a consistent order if
- # you specify a seed.
- self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
- seed=21345)
-
- def testIndefiniteRepeatShapeInference(self):
- dataset = readers.make_tf_record_dataset(
- file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
- for shape in nest.flatten(dataset.output_shapes):
- self.assertEqual(32, shape[0])
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
new file mode 100644
index 0000000000..657cf3c00e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
@@ -0,0 +1,243 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.make_tf_record_dataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class MakeTFRecordDatasetTest(
+ reader_dataset_ops_test_base.TFRecordDatasetTestBase):
+
+ def _interleave(self, iterators, cycle_length):
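+    """Reference interleaving of `iterators` with the given cycle length."""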
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length,
+ drop_final_batch,
+ use_parser_fn):
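+    """Yields the batches expected for the given read configuration."""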
+
+ def _next_record(file_indices):
+ for j in file_indices:
+ for i in range(self._num_records):
+ yield j, i
+
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
+ record_batch = []
+ batch_index = 0
+ for _ in range(num_epochs):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for f, r in next_records:
+ record = self._record(f, r)
+ if use_parser_fn:
+ record = record[1:]
+ record_batch.append(record)
+ batch_index += 1
+ if len(record_batch) == batch_size:
+ yield record_batch
+ record_batch = []
+ batch_index = 0
+ if record_batch and not drop_final_batch:
+ yield record_batch
+
+ def _verify_records(self,
+ sess,
+ outputs,
+ batch_size,
+ file_index,
+ num_epochs,
+ interleave_cycle_length,
+ drop_final_batch,
+ use_parser_fn):
+ if file_index is not None:
+ file_indices = [file_index]
+ else:
+ file_indices = range(self._num_files)
+
+ for expected_batch in self._next_expected_batch(
+ file_indices, batch_size, num_epochs, interleave_cycle_length,
+ drop_final_batch, use_parser_fn):
+ actual_batch = sess.run(outputs)
+ self.assertAllEqual(expected_batch, actual_batch)
+
+ def _read_test(self, batch_size, num_epochs, file_index=None,
+ num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
+ if file_index is None:
+ file_pattern = self.test_filenames
+ else:
+ file_pattern = self.test_filenames[file_index]
+
+ if parser_fn:
+ fn = lambda x: string_ops.substr(x, 1, 999)
+ else:
+ fn = None
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs = readers.make_tf_record_dataset(
+ file_pattern=file_pattern,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ parser_fn=fn,
+ num_parallel_reads=num_parallel_reads,
+ drop_final_batch=drop_final_batch,
+ shuffle=False).make_one_shot_iterator().get_next()
+ self._verify_records(
+ sess, outputs, batch_size, file_index, num_epochs=num_epochs,
+ interleave_cycle_length=num_parallel_reads,
+ drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(outputs)
+
+ def testRead(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ # Basic test: read from file 0.
+ self._read_test(batch_size, num_epochs, 0)
+
+ # Basic test: read from file 1.
+ self._read_test(batch_size, num_epochs, 1)
+
+ # Basic test: read from both files.
+ self._read_test(batch_size, num_epochs)
+
+ # Basic test: read from both files, with parallel reads.
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8)
+
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2, 10]:
+ for num_epochs in [1, 3]:
+ # Read from file 0.
+ self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
+
+ # Read from both files.
+ self._read_test(batch_size, num_epochs, drop_final_batch=True)
+
+ # Read from both files, with parallel reads.
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8,
+ drop_final_batch=True)
+
+ def testParserFn(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ for drop_final_batch in [False, True]:
+ self._read_test(batch_size, num_epochs, parser_fn=True,
+ drop_final_batch=drop_final_batch)
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8,
+ parser_fn=True, drop_final_batch=drop_final_batch)
+
+ def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
+ seed=None):
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ num_parallel_reads=num_parallel_reads,
+ shuffle=True,
+ shuffle_seed=seed)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ sess.run(iterator.initializer)
+ first_batches = []
+ try:
+ while True:
+ first_batches.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+
+ sess.run(iterator.initializer)
+ second_batches = []
+ try:
+ while True:
+ second_batches.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+
+ self.assertEqual(len(first_batches), len(second_batches))
+ if seed is not None:
+        # If a seed is set, reinitializing the iterator should yield the same
+        # batches in the same order.
+ for i in range(len(first_batches)):
+ self.assertAllEqual(first_batches[i], second_batches[i])
+
+ expected = []
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ expected.extend([self._record(f, r)] * num_epochs)
+
+ for batches in (first_batches, second_batches):
+ actual = []
+ for b in batches:
+ actual.extend(b)
+ self.assertAllEqual(sorted(expected), sorted(actual))
+
+ def testShuffle(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ for num_parallel_reads in [1, 2]:
+ # Test that all expected elements are produced
+ self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
+ # Test that elements are produced in a consistent order if
+ # you specify a seed.
+ self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
+ seed=21345)
+
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
+ for shape in nest.flatten(dataset.output_shapes):
+ self.assertEqual(32, shape[0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
new file mode 100644
index 0000000000..afd0fc3abf
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -0,0 +1,337 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.map_and_batch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
+ )
+ def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
+ """Test a dataset that maps a TF function across its input elements."""
+ # The pipeline is TensorSliceDataset ->
+ # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ count = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ num_parallel_batches=num_parallel_batches))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ [t.shape.as_list() for t in get_next])
+
+ with self.cached_session() as sess:
+ # Batch of a finite input, where the batch_size divides the
+ # total number of elements.
+ sess.run(init_op, feed_dict={count: 28, batch_size: 14})
+ num_batches = (28 * 7) // 14
+ for i in range(num_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(14):
+ self.assertAllEqual(component[(i * 14 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Batch of a finite input, where the batch_size does not
+ # divide the total number of elements.
+ sess.run(init_op, feed_dict={count: 14, batch_size: 8})
+
+ # We expect (num_batches - 1) full-sized batches.
+ num_batches = int(math.ceil((14 * 7) / 8))
+ for i in range(num_batches - 1):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(8):
+ self.assertAllEqual(component[(i * 8 + j) % 7]**2,
+ result_component[j])
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range((14 * 7) % 8):
+ self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Batch of an empty input should fail straight away.
+ sess.run(init_op, feed_dict={count: 0, batch_size: 8})
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Empty batch should be an initialization time error.
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(init_op, feed_dict={count: 14, batch_size: 0})
+
+ @parameterized.named_parameters(
+ ("Even", False),
+ ("Uneven", True),
+ )
+ def testMapAndBatchPartialBatch(self, drop_remainder):
+ iterator = (
+ dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]),
+ batch_size=4,
+ drop_remainder=drop_remainder)).make_one_shot_iterator())
+ if drop_remainder:
+ self.assertEqual([4, 1], iterator.output_shapes.as_list())
+ else:
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ if not drop_remainder:
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMapAndBatchYieldsPartialBatch(self):
+ iterator = (dataset_ops.Dataset.range(10)
+ .apply(batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]), 4))
+ .make_one_shot_iterator())
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMapAndBatchParallelGetNext(self):
+ iterator = (dataset_ops.Dataset.range(50000)
+ .apply(batching.map_and_batch(lambda x: x, batch_size=100))
+ .make_one_shot_iterator())
+ elements = []
+ for _ in range(100):
+ elements.append(iterator.get_next())
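+    # Each sess.run() evaluates 100 get_next ops, i.e. 100 batches of 100
+    # elements, so 5 runs consume all 50000 elements. Sorting by the first
+    # element makes the comparison independent of the order in which the
+    # concurrent get_next calls complete.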
+ with self.cached_session() as sess:
+ for i in range(5):
+ got = sess.run(elements)
+ got.sort(key=lambda x: x[0])
+ expected = []
+ for j in range(100):
+ expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ self.assertAllEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elements)
+
+ def testMapAndBatchParallelGetNextDropRemainder(self):
+ iterator = (
+ dataset_ops.Dataset.range(49999).apply(
+ batching.map_and_batch(
+ lambda x: x, batch_size=100, drop_remainder=True))
+ .make_one_shot_iterator())
+ elements = []
+ for _ in range(100):
+ elements.append(iterator.get_next())
+ with self.cached_session() as sess:
+ for i in range(4):
+ got = sess.run(elements)
+ got.sort(key=lambda x: x[0])
+ expected = []
+ for j in range(100):
+ expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ self.assertAllEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elements)
+
+ def testMapAndBatchSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for i in range(2):
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
+ dense_shape=[5, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testMapAndBatchFails(self):
+    """Test that an error raised inside the map function is reported."""
+ dataset = dataset_ops.Dataset.from_tensors(
+ array_ops.check_numerics(
+ constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
+ sess.run(init_op, feed_dict={batch_size: 14})
+
+ def testMapAndBatchShapeMismatch(self):
+    """Test that batching elements with mismatched shapes raises an error."""
+
+ def generator():
+ yield [1]
+ yield [2]
+ yield [3]
+ yield [[4, 5, 6]]
+
+ dataset = dataset_ops.Dataset.from_generator(
+ generator, output_types=dtypes.int32)
+ batch_size = 4
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "number of elements does not match"):
+ sess.run(get_next)
+
+ def testMapAndBatchImplicitDispose(self):
+ # Tests whether a map and batch dataset will be cleaned up correctly when
+ # the pipeline does not run it until exhaustion.
+ # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
+ # MapAndBatchDataset(f=square_3, batch_size=100).
+ components = (np.arange(1000),
+ np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
+ np.array(37.0) * np.arange(1000))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
+ 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
+ dataset = dataset.prefetch(5)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(3):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
+ def testMapAndBatchOutOfRangeError(self, threshold):
+
+ def raising_py_fn(i):
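+      # Simulate an input that ends early: elements at or beyond `threshold`
+      # raise, and the test checks that all complete and partial batches
+      # before that point are still produced.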
+ if i >= threshold:
+ raise StopIteration()
+ else:
+ return i
+
+ iterator = (
+ dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(
+ lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
+ batch_size=10)).make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(threshold // 10):
+ self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
+ if threshold % 10 != 0:
+ self.assertAllEqual(
+ [threshold // 10 * 10 + j for j in range(threshold % 10)],
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
+ )
+ def testMapAndBatchTypes(self, element, dtype):
+ def gen():
+ yield element
+
+ dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
+ batching.map_and_batch(lambda x: x, batch_size=10))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(10):
+ self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
index 4432dcb05a..5e419a9b2f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline statistics gathering ops."""
+"""Tests for the private `override_threadpool()` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -32,8 +32,8 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
- parameterized.TestCase):
+class OverrideThreadpoolTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),
diff --git a/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
index 560902caad..90ac250df7 100644
--- a/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.parallel_interleave()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -37,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
+class ParallelInterleaveTest(test_base.DatasetTestBase):
def setUp(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py
index 13f924b656..723e709ae8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tensorflow.ops.parsing_ops."""
+"""Tests for `tf.data.experimental.parse_example_dataset()`."""
from __future__ import absolute_import
from __future__ import division
@@ -73,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
i += 1
-class ParseExampleTest(test_base.DatasetTestBase):
+class ParseExampleDatasetTest(test_base.DatasetTestBase):
def _test(self,
input_tensor,
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
new file mode 100644
index 0000000000..f73725366c
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
@@ -0,0 +1,234 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.prefetch_to_device()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
+
+ def testPrefetchToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))
+
+    # NOTE(mrry): This device block creates the "host" dataset and iterator on
+    # /cpu:0, which is also the device to which the elements are prefetched.
+    # This covers the case where the source and target of `prefetch_to_device`
+    # are the same device, so no cross-device copy is involved.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchSparseTensorsToDevice(self):
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
+if __name__ == "__main__":
+ test.main()
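[Editor's note: for context, a minimal usage sketch of the transformation exercised by the new test file above. It is illustrative only and not part of this change; the device string "/gpu:0" and the element values are assumptions.]

  import tensorflow as tf
  from tensorflow.python.data.experimental.ops import prefetching_ops

  # Build a small host-side pipeline, then stage its output on the GPU.
  dataset = tf.data.Dataset.range(1000).map(lambda x: x * 2)
  dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))

  # `prefetch_to_device` must be the last transformation in the pipeline;
  # the iterator then yields elements that already live on the target device.
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()

  with tf.Session() as sess:
    while True:
      try:
        sess.run(next_element)
      except tf.errors.OutOfRangeError:
        break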
diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
index b6ab80d132..fe0b3b5f3b 100644
--- a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
@@ -63,11 +63,11 @@ class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
return filenames
-class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
+class MakeBatchedFeaturesDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing `make_batched_feature_dataset`."""
def setUp(self):
- super(ReadBatchFeaturesTestBase, self).setUp()
+ super(MakeBatchedFeaturesDatasetTestBase, self).setUp()
self._num_files = 2
self._num_records = 7
self.test_filenames = self._createFiles()
diff --git a/tensorflow/python/data/experimental/kernel_tests/resample_test.py b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
index 775648c943..4c879dbae6 100644
--- a/tensorflow/python/data/experimental/kernel_tests/resample_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.rejection_resample()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -58,7 +58,7 @@ def _time_resampling(
return end_time - start_time
-class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
+class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py
index 3fc7157bc5..516e489d04 100644
--- a/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the private `_RestructuredDataset` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -26,7 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test_base.DatasetTestBase):
+class RestructuredDatasetTest(test_base.DatasetTestBase):
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),
diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
index 78ec80de23..0730455431 100644
--- a/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.scan()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -34,7 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ScanDatasetTest(test_base.DatasetTestBase):
+class ScanTest(test_base.DatasetTestBase):
def _counting_dataset(self, start, scan_fn):
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index 58a335ae4f..e556b65b7c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -70,6 +70,26 @@ py_test(
)
py_test(
+ name = "checkpoint_input_pipeline_hook_test",
+ size = "small",
+ srcs = ["checkpoint_input_pipeline_hook_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
name = "concatenate_dataset_serialization_test",
size = "small",
srcs = ["concatenate_dataset_serialization_test.py"],
@@ -580,7 +600,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_op_test_base",
+ "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_test_base",
"//tensorflow/python/data/experimental/ops:readers",
],
)
diff --git a/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
index 94393d6d4b..94393d6d4b 100644
--- a/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
index a0dd6960b0..b3dfe21486 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -23,7 +23,7 @@ from tensorflow.python.platform import test
class ParseExampleDatasetSerializationTest(
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
dataset_serialization_test_base.DatasetSerializationTestBase):
def ParseExampleDataset(self, num_repeat, batch_size):
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
index b179770ce3..006279bbe1 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import os
-from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.framework import dtypes
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class SqlDatasetSerializationTest(
- sql_dataset_op_test_base.SqlDatasetTestBase,
+ sql_dataset_test_base.SqlDatasetTestBase,
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_dataset(self, num_repeats):
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
deleted file mode 100644
index 88d5c896c9..0000000000
--- a/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Integration test for dataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
-from tensorflow.python.training import saver as saver_lib
-
-
-class SerializationIntegrationTest(test.TestCase):
-
- def _build_input_pipeline(self, name, num_outputs):
- with ops.name_scope(name):
- ds = dataset_ops.Dataset.range(num_outputs).shuffle(
- 10, reshuffle_each_iteration=False).prefetch(10)
- iterator = ds.make_initializable_iterator()
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- return iterator.initializer, iterator.get_next()
-
- def _build_graph(self, num_pipelines, num_outputs):
- init_ops = []
- get_next_ops = []
- for i in range(num_pipelines):
- name = "input_pipeline_%d" % i
- init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
- init_ops.append(init_op)
- get_next_ops.append(get_next_op)
- saver = saver_lib.Saver()
- return init_ops, get_next_ops, saver
-
- def _ckpt_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def testConcurrentSaves(self):
- num_pipelines = 100
- num_outputs = 100
- break_point = 10
- all_outputs = [[] for _ in range(num_pipelines)]
- with ops.Graph().as_default() as g:
- init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
- num_outputs)
- with self.session(graph=g) as sess:
- sess.run(init_ops)
- for _ in range(break_point):
- output = sess.run(get_next_ops)
- for i in range(num_pipelines):
- all_outputs[i].append(output[i])
- saver.save(sess, self._ckpt_path())
-
- with ops.Graph().as_default() as g:
- init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
- num_outputs)
- with self.session(graph=g) as sess:
- saver.restore(sess, self._ckpt_path())
- for _ in range(num_outputs - break_point):
- output = sess.run(get_next_ops)
- for i in range(num_pipelines):
- all_outputs[i].append(output[i])
-
- for output in all_outputs:
- self.assertSequenceEqual(sorted(output), range(num_outputs))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
index 50895b5945..c208963a86 100644
--- a/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.shuffle_and_repeat()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
index 301f75488a..a2c1169638 100644
--- a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
@@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for experimental sql input op."""
+"""Tests for `tf.data.experimental.SqlDataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
+class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that SqlDataset can read from a database table.
def testReadResultSet(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py
index a135c357f0..6aaaa90c65 100644
--- a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Base class for testing SqlDataset."""
-
+"""Base class for testing `tf.data.experimental.SqlDataset`."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index 19f5a62d45..427654cd76 100644
--- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -280,7 +280,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
class FeatureStatsDatasetTest(
stats_dataset_test_base.StatsDatasetTestBase,
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
def testFeaturesStats(self):
num_epochs = 5
diff --git a/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py
index 25a2e63ba1..8fd0ad50c4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.TFRecordWriter`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py b/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
new file mode 100644
index 0000000000..0278a208cb
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
@@ -0,0 +1,300 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.unbatch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testUnbatchWithUnknownRankInput(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
+ batching.unbatch())
+ iterator = dataset.make_initializable_iterator()
+ next_elem = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
+ for i in range(4):
+ self.assertEqual(i, sess.run(next_elem))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_elem)
+
+ def testUnbatchScalarDataset(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = (dtypes.int32,) * 3
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual((i,) * 3, sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithStrings(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
+ expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors(st)
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ st_row = sess.run(next_element)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchDatasetWithDenseAndSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ dense_elem, st_row = sess.run(next_element)
+ self.assertEqual(i, dense_elem)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchSingleElementTupleDataset(self):
+ data = tuple([(math_ops.range(10),) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = ((dtypes.int32,),) * 3
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(((i,),) * 3, sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchMultiElementTupleDataset(self):
+ data = tuple([(math_ops.range(10 * i, 10 * i + 10),
+ array_ops.fill([10], "hi")) for i in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = ((dtypes.int32, dtypes.string),) * 3
+ data = data.batch(2)
+ self.assertAllEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertAllEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
+ sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchEmpty(self):
+ data = dataset_ops.Dataset.from_tensors(
+ (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
+ constant_op.constant([], shape=[0, 4, 0])))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchStaticShapeMismatch(self):
+ data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
+ np.arange(9)))
+ with self.assertRaises(ValueError):
+ data.apply(batching.unbatch())
+
+ def testUnbatchDynamicShapeMismatch(self):
+ ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
+ ph2 = array_ops.placeholder(dtypes.int32, shape=None)
+ data = dataset_ops.Dataset.from_tensors((ph1, ph2))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # Mismatch in the 0th dimension.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: np.arange(8).astype(np.int32)
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(next_element)
+
+ # No 0th dimension (i.e. scalar value) for one component.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: 7
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(next_element)
+
+
+class UnbatchBenchmark(test.Benchmark):
+
+ def benchmarkNativeUnbatch(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.apply(batching.unbatch())
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (native) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_native_batch_size_%d" %
+ batch_size)
+
+ # Include a benchmark of the previous `unbatch()` implementation that uses
+ # a composition of more primitive ops. Eventually we'd hope to generate code
+ # that is as good in both cases.
+ def benchmarkOldUnbatchImplementation(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (unfused) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
+ batch_size)
+
+
+if __name__ == "__main__":
+ test.main()
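[Editor's note: a minimal usage sketch of the `unbatch()` transformation covered by the new test file above. Illustrative only and not part of this change; the input values are assumptions.]

  import tensorflow as tf
  from tensorflow.python.data.experimental.ops import batching

  # A dataset of two batches of size 4; unbatch() splits each batch back into
  # individual elements, so the pipeline yields the 8 scalars 0..7 in order.
  dataset = tf.data.Dataset.from_tensor_slices([[0, 1, 2, 3], [4, 5, 6, 7]])
  dataset = dataset.apply(batching.unbatch())
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()

  with tf.Session() as sess:
    for expected in range(8):
      assert sess.run(next_element) == expected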
diff --git a/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_test.py
index b5a0b20f3f..847cff26b0 100644
--- a/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/unique_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.unique()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -26,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class UniqueDatasetTest(test_base.DatasetTestBase):
+class UniqueTest(test_base.DatasetTestBase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 230ae3f3fd..0c372ebb10 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.Dataset.map()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -267,6 +267,35 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testCaptureIterator(self):
+
+ def _build_ds(iterator):
+
+ def _map_fn(x):
+ get_next = iterator.get_next()
+ return x * get_next
+
+ return dataset_ops.Dataset.range(10).map(_map_fn)
+
+ def _build_graph():
+ captured_iterator = dataset_ops.Dataset.range(
+ 10).make_initializable_iterator()
+ ds = _build_ds(captured_iterator)
+ iterator = ds.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return captured_iterator.initializer, init_op, get_next
+
+ with ops.Graph().as_default() as g:
+ captured_init_op, init_op, get_next = _build_graph()
+ with self.session(graph=g) as sess:
+ sess.run(captured_init_op)
+ sess.run(init_op)
+ for i in range(10):
+ self.assertEqual(i * i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testCaptureHashTable(self):
# NOTE(mrry): We must use the V2 variants of `HashTable`
# etc. because these produce a `tf.resource`-typed output that is