author    Derek Murray <mrry@google.com>  2018-10-01 16:45:11 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-01 16:50:05 -0700
commit    b72265dc002e712fc3d0f33434f13c7a36a484b2 (patch)
tree      f92d1f23c329654772f95d93f5cf4458741b72df /tensorflow/python/data
parent    bb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (diff)
[tf.data] Deprecate `tf.contrib.data` and introduce `tf.data.experimental` to replace it.
This change prepares `tf.data` for TensorFlow 2.0, where `tf.contrib` will no longer exist. It retains the pre-existing endpoints in `tf.contrib.data` with deprecation warnings.

Note there are some exceptions to the move:

* Deprecated symbols in `tf.contrib.data` have not been moved to `tf.data.experimental`, because replacements already exist.
* `tf.contrib.data.LMDBDataset` has not been moved, because we plan to move it to a SIG-maintained repository.
* `tf.contrib.data.assert_element_shape()` has not yet been moved, because it depends on functionality in `tf.contrib`, and it will move in a later change.
* `tf.contrib.data.AUTOTUNE` has not yet been moved, because we have not yet determined how to `tf_export()` a Python integer.
* The stats-related API endpoints have not yet appeared in a released version of TensorFlow, so these are moved to `tf.data.experimental` without retaining an endpoint in `tf.contrib.data`.

In addition, this change includes some build rule and ApiDef refactoring:

* Some of the "//third_party/tensorflow/python:training" dependencies had to be split in order to avoid a circular dependency.
* The `tf.contrib.stateless` ops now have a private core library for the generated wrappers (and accordingly are hidden in their ApiDef) so that `tf.data.experimental.sample_from_datasets()` can depend on them.

PiperOrigin-RevId: 215304249
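
For illustration only (not part of the commit): for user code the migration is essentially a rename of the module path, since the retained `tf.contrib.data` endpoints keep their signatures. A minimal sketch, assuming the TF 1.x graph-mode iterator idiom used in the tests below and the `map_and_batch` transformation exported from `tf.data.experimental`:

    import tensorflow as tf

    # A simple pipeline of squared integers, batched into groups of 10.
    dataset = tf.data.Dataset.range(100)

    # Before this change (still works, but now emits a deprecation warning):
    #   dataset = dataset.apply(
    #       tf.contrib.data.map_and_batch(lambda x: x * x, batch_size=10))

    # After this change, the same transformation lives in tf.data.experimental:
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(lambda x: x * x, batch_size=10))

    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()

    with tf.Session() as sess:
      print(sess.run(next_batch))  # first batch: squares of 0..9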
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--  tensorflow/python/data/BUILD | 1
-rw-r--r--  tensorflow/python/data/__init__.py | 1
-rw-r--r--  tensorflow/python/data/experimental/BUILD | 16
-rw-r--r--  tensorflow/python/data/experimental/__init__.py | 109
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/BUILD | 569
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py | 672
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/bucketing_test.py | 824
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py | 632
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py | 71
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py | 692
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py | 148
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py | 76
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py | 72
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py | 79
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py | 811
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py | 125
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py | 359
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py | 281
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/BUILD | 164
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py | 65
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py | 103
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py | 58
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py | 225
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py | 85
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py | 223
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py | 183
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py | 58
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py | 109
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py | 850
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py | 948
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py | 78
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py | 1083
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py | 353
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/resample_test.py | 182
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py | 172
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/BUILD | 555
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py | 83
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py | 253
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py | 49
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py | 73
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py | 95
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py | 692
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py | 71
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py | 45
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py | 122
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py | 61
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py | 57
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py | 46
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py | 83
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py | 88
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py | 140
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py | 39
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py | 66
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py | 101
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py | 139
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py | 50
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py | 39
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py | 118
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py | 46
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py | 40
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py | 129
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py | 85
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py | 39
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py | 148
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py | 53
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py | 106
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py | 53
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py | 99
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py | 51
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py | 40
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py | 54
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py | 85
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py | 115
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py | 590
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py | 95
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py | 253
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py | 71
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py | 91
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py | 83
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py | 118
-rw-r--r--  tensorflow/python/data/experimental/ops/BUILD | 377
-rw-r--r--  tensorflow/python/data/experimental/ops/batching.py | 669
-rw-r--r--  tensorflow/python/data/experimental/ops/counter.py | 55
-rw-r--r--  tensorflow/python/data/experimental/ops/enumerate_ops.py | 60
-rw-r--r--  tensorflow/python/data/experimental/ops/error_ops.py | 78
-rw-r--r--  tensorflow/python/data/experimental/ops/get_single_element.py | 72
-rw-r--r--  tensorflow/python/data/experimental/ops/grouping.py | 551
-rw-r--r--  tensorflow/python/data/experimental/ops/indexed_dataset_ops.py | 177
-rw-r--r--  tensorflow/python/data/experimental/ops/interleave_ops.py | 262
-rw-r--r--  tensorflow/python/data/experimental/ops/iterator_ops.py | 268
-rw-r--r--  tensorflow/python/data/experimental/ops/map_defun.py | 56
-rw-r--r--  tensorflow/python/data/experimental/ops/optimization.py | 171
-rw-r--r--  tensorflow/python/data/experimental/ops/parsing_ops.py | 152
-rw-r--r--  tensorflow/python/data/experimental/ops/prefetching_ops.py | 531
-rw-r--r--  tensorflow/python/data/experimental/ops/random_ops.py | 54
-rw-r--r--  tensorflow/python/data/experimental/ops/readers.py | 904
-rw-r--r--  tensorflow/python/data/experimental/ops/resampling.py | 296
-rw-r--r--  tensorflow/python/data/experimental/ops/scan_ops.py | 177
-rw-r--r--  tensorflow/python/data/experimental/ops/shuffle_ops.py | 102
-rw-r--r--  tensorflow/python/data/experimental/ops/stats_ops.py | 205
-rw-r--r--  tensorflow/python/data/experimental/ops/threadpool.py | 104
-rw-r--r--  tensorflow/python/data/experimental/ops/unique.py | 79
-rw-r--r--  tensorflow/python/data/experimental/ops/writers.py | 60
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py | 4
-rw-r--r--  tensorflow/python/data/ops/optional_ops.py | 4
-rw-r--r--  tensorflow/python/data/ops/readers.py | 4
106 files changed, 21452 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index 138141f4fc..e32eeecbb8 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -10,6 +10,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
+ "//tensorflow/python/data/experimental",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py
index f8b561205e..7536ba668a 100644
--- a/tensorflow/python/data/__init__.py
+++ b/tensorflow/python/data/__init__.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
+from tensorflow.python.data import experimental
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.ops.iterator_ops import Iterator
from tensorflow.python.data.ops.readers import FixedLengthRecordDataset
diff --git a/tensorflow/python/data/experimental/BUILD b/tensorflow/python/data/experimental/BUILD
new file mode 100644
index 0000000000..84e761d376
--- /dev/null
+++ b/tensorflow/python/data/experimental/BUILD
@@ -0,0 +1,16 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "experimental",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py
new file mode 100644
index 0000000000..2ac159d38a
--- /dev/null
+++ b/tensorflow/python/data/experimental/__init__.py
@@ -0,0 +1,109 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for building input pipelines.
+
+This module contains experimental `Dataset` sources and transformations that can
+be used in conjunction with the `tf.data.Dataset` API. Note that the
+`tf.data.experimental` API is not subject to the same backwards compatibility
+guarantees as `tf.data`, but we will provide deprecation advice in advance of
+removing existing functionality.
+
+See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
+
+@@Counter
+@@CheckpointInputPipelineHook
+@@CsvDataset
+@@Optional
+@@RandomDataset
+@@Reducer
+@@SqlDataset
+@@TFRecordWriter
+
+@@bucket_by_sequence_length
+@@choose_from_datasets
+@@copy_to_device
+@@dense_to_sparse_batch
+@@enumerate_dataset
+@@get_next_as_optional
+@@get_single_element
+@@group_by_reducer
+@@group_by_window
+@@ignore_errors
+@@latency_stats
+@@make_batched_features_dataset
+@@make_csv_dataset
+@@make_saveable_from_iterator
+@@map_and_batch
+@@parallel_interleave
+@@parse_example_dataset
+@@prefetch_to_device
+@@rejection_resample
+@@sample_from_datasets
+@@scan
+@@set_stats_aggregator
+@@shuffle_and_repeat
+@@StatsAggregator
+@@unbatch
+@@unique
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+
+from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch
+from tensorflow.python.data.experimental.ops.batching import map_and_batch
+from tensorflow.python.data.experimental.ops.batching import unbatch
+from tensorflow.python.data.experimental.ops.counter import Counter
+from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
+from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
+from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
+from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length
+from tensorflow.python.data.experimental.ops.grouping import group_by_reducer
+from tensorflow.python.data.experimental.ops.grouping import group_by_window
+from tensorflow.python.data.experimental.ops.grouping import Reducer
+from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets
+from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave
+from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
+from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
+from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+
+from tensorflow.python.data.experimental.ops.parsing_ops import parse_example_dataset
+from tensorflow.python.data.experimental.ops.prefetching_ops import copy_to_device
+from tensorflow.python.data.experimental.ops.prefetching_ops import prefetch_to_device
+from tensorflow.python.data.experimental.ops.random_ops import RandomDataset
+from tensorflow.python.data.experimental.ops.readers import CsvDataset
+from tensorflow.python.data.experimental.ops.readers import make_batched_features_dataset
+from tensorflow.python.data.experimental.ops.readers import make_csv_dataset
+from tensorflow.python.data.experimental.ops.readers import SqlDataset
+from tensorflow.python.data.experimental.ops.resampling import rejection_resample
+from tensorflow.python.data.experimental.ops.scan_ops import scan
+from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat
+from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
+from tensorflow.python.data.experimental.ops.stats_ops import set_stats_aggregator
+from tensorflow.python.data.experimental.ops.stats_ops import StatsAggregator
+from tensorflow.python.data.experimental.ops.unique import unique
+from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
+from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
+from tensorflow.python.data.ops.optional_ops import Optional
+# pylint: enable=unused-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
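
For illustration only (not part of the commit): the transformations documented in the module docstring above compose with the core API through `Dataset.apply()`. A minimal sketch, assuming the TF 1.x one-shot-iterator idiom and the `Counter` and `shuffle_and_repeat` symbols re-exported in this module:

    import tensorflow as tf

    # Counter() is an unbounded int64 source; take() makes it finite.
    dataset = tf.data.experimental.Counter().take(1000)

    # shuffle_and_repeat() fuses Dataset.shuffle() and Dataset.repeat():
    # here, two shuffled epochs with a 100-element shuffle buffer.
    dataset = dataset.apply(
        tf.data.experimental.shuffle_and_repeat(buffer_size=100, count=2))
    dataset = dataset.batch(32)

    next_batch = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
      print(sess.run(next_batch).shape)  # (32,)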
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
new file mode 100644
index 0000000000..a46c30ed2e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -0,0 +1,569 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "batch_dataset_op_test",
+ size = "medium",
+ srcs = ["batch_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss", # (b/79552534)
+ "no_pip",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "bucketing_test",
+ size = "medium",
+ srcs = ["bucketing_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "csv_dataset_op_test",
+ size = "medium",
+ srcs = ["csv_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/eager:context",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "dataset_constructor_op_test",
+ size = "medium",
+ srcs = ["dataset_constructor_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "manual",
+ "nomac", # b/62040583
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_test(
+ name = "directed_interleave_dataset_test",
+ size = "medium",
+ srcs = ["directed_interleave_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "get_single_element_test",
+ size = "small",
+ srcs = ["get_single_element_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:get_single_element",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "indexed_dataset_ops_test",
+ srcs = ["indexed_dataset_ops_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/experimental/ops:indexed_dataset_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "interleave_dataset_op_test",
+ size = "medium",
+ srcs = ["interleave_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "notap",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
+ name = "map_dataset_op_test",
+ size = "medium",
+ srcs = ["map_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "noasan", # times out
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_op_test",
+ size = "medium",
+ srcs = ["filter_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "map_defun_op_test",
+ size = "small",
+ srcs = ["map_defun_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:map_defun",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "parsing_ops_test",
+ size = "small",
+ srcs = ["parsing_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:parsing_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetching_ops_test",
+ size = "small",
+ srcs = ["prefetching_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+py_test(
+ name = "range_dataset_op_test",
+ size = "small",
+ srcs = ["range_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:counter",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "reader_dataset_ops_test_base",
+ testonly = 1,
+ srcs = [
+ "reader_dataset_ops_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/python/data/experimental/kernel_tests:__pkg__",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "reader_dataset_ops_test",
+ size = "medium",
+ srcs = ["reader_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "resample_test",
+ size = "medium",
+ srcs = ["resample_test.py"],
+ shard_count = 2,
+ srcs_version = "PY2AND3",
+ tags = [
+ "noasan",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:resampling",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "scan_dataset_op_test",
+ size = "small",
+ srcs = ["scan_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "shuffle_dataset_op_test",
+ size = "medium",
+ srcs = ["shuffle_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "sql_dataset_op_test_base",
+ srcs = ["sql_dataset_op_test_base.py"],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/python/data/experimental/kernel_tests:__pkg__",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "@org_sqlite//:python",
+ ],
+)
+
+py_test(
+ name = "sql_dataset_op_test",
+ size = "small",
+ srcs = ["sql_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":sql_dataset_op_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ ],
+)
+
+py_test(
+ name = "stats_dataset_ops_test",
+ size = "medium",
+ srcs = ["stats_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ ":stats_dataset_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "stats_dataset_test_base",
+ srcs = ["stats_dataset_test_base.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "threadpool_dataset_ops_test",
+ size = "small",
+ srcs = ["threadpool_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python/data/experimental/ops:threadpool",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "unique_dataset_op_test",
+ size = "small",
+ srcs = ["unique_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "writer_ops_test",
+ size = "small",
+ srcs = ["writer_ops_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:writers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
new file mode 100644
index 0000000000..8703b2810e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
@@ -0,0 +1,672 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import time
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testDenseToSparseBatchDataset(self):
+ components = np.random.randint(12, size=(100,)).astype(np.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x], x)).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual([[i, j]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)], results.indices)
+ self.assertAllEqual(
+ [c for c in components[start:start + 4] for _ in range(c)],
+ results.values)
+ self.assertAllEqual([min(4,
+ len(components) - start), 12],
+ results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithUnknownShape(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x, x], x)).apply(
+ batching.dense_to_sparse_batch(
+ 4, [5, None])).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual([[i, j, z]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)
+ for z in range(c)], results.indices)
+ self.assertAllEqual([
+ c
+ for c in components[start:start + 4] for _ in range(c)
+ for _ in range(c)
+ ], results.values)
+ self.assertAllEqual([
+ min(4,
+ len(components) - start), 5,
+ np.max(components[start:start + 4])
+ ], results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithInvalidShape(self):
+ input_tensor = array_ops.constant([[1]])
+ with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
+
+ def testDenseToSparseBatchDatasetShapeErrors(self):
+ input_tensor = array_ops.placeholder(dtypes.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # Initialize with an input tensor of incompatible rank.
+ sess.run(init_op, feed_dict={input_tensor: [[1]]})
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "incompatible with the row shape"):
+ sess.run(get_next)
+
+ # Initialize with an input tensor that is larger than `row_shape`.
+ sess.run(init_op, feed_dict={input_tensor: range(13)})
+ with self.assertRaisesRegexp(errors.DataLossError,
+ "larger than the row shape"):
+ sess.run(get_next)
+
+ def testUnbatchScalarDataset(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = (dtypes.int32,) * 3
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual((i,) * 3, sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithStrings(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
+ expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors(st)
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ st_row = sess.run(next_element)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchDatasetWithDenseAndSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ dense_elem, st_row = sess.run(next_element)
+ self.assertEqual(i, dense_elem)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchSingleElementTupleDataset(self):
+ data = tuple([(math_ops.range(10),) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = ((dtypes.int32,),) * 3
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(((i,),) * 3, sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchMultiElementTupleDataset(self):
+ data = tuple([(math_ops.range(10 * i, 10 * i + 10),
+ array_ops.fill([10], "hi")) for i in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = ((dtypes.int32, dtypes.string),) * 3
+ data = data.batch(2)
+ self.assertAllEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertAllEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
+ sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchEmpty(self):
+ data = dataset_ops.Dataset.from_tensors(
+ (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
+ constant_op.constant([], shape=[0, 4, 0])))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchStaticShapeMismatch(self):
+ data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
+ np.arange(9)))
+ with self.assertRaises(ValueError):
+ data.apply(batching.unbatch())
+
+ def testUnbatchDynamicShapeMismatch(self):
+ ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
+ ph2 = array_ops.placeholder(dtypes.int32, shape=None)
+ data = dataset_ops.Dataset.from_tensors((ph1, ph2))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # Mismatch in the 0th dimension.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: np.arange(8).astype(np.int32)
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(next_element)
+
+ # No 0th dimension (i.e. scalar value) for one component.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: 7
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(next_element)
+
+ @parameterized.named_parameters(
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
+ )
+ def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
+ """Test a dataset that maps a TF function across its input elements."""
+ # The pipeline is TensorSliceDataset ->
+ # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ count = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ num_parallel_batches=num_parallel_batches))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ [t.shape.as_list() for t in get_next])
+
+ with self.cached_session() as sess:
+ # Batch of a finite input, where the batch_size divides the
+ # total number of elements.
+ sess.run(init_op, feed_dict={count: 28, batch_size: 14})
+ num_batches = (28 * 7) // 14
+ for i in range(num_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(14):
+ self.assertAllEqual(component[(i * 14 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Batch of a finite input, where the batch_size does not
+ # divide the total number of elements.
+ sess.run(init_op, feed_dict={count: 14, batch_size: 8})
+
+ # We expect (num_batches - 1) full-sized batches.
+ num_batches = int(math.ceil((14 * 7) / 8))
+ for i in range(num_batches - 1):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(8):
+ self.assertAllEqual(component[(i * 8 + j) % 7]**2,
+ result_component[j])
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range((14 * 7) % 8):
+ self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Batch of an empty input should fail straight away.
+ sess.run(init_op, feed_dict={count: 0, batch_size: 8})
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Empty batch should be an initialization time error.
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(init_op, feed_dict={count: 14, batch_size: 0})
+
+ @parameterized.named_parameters(
+ ("Even", False),
+ ("Uneven", True),
+ )
+ def testMapAndBatchPartialBatch(self, drop_remainder):
+ iterator = (
+ dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]),
+ batch_size=4,
+ drop_remainder=drop_remainder)).make_one_shot_iterator())
+ if drop_remainder:
+ self.assertEqual([4, 1], iterator.output_shapes.as_list())
+ else:
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ if not drop_remainder:
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMapAndBatchYieldsPartialBatch(self):
+ iterator = (dataset_ops.Dataset.range(10)
+ .apply(batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]), 4))
+ .make_one_shot_iterator())
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMapAndBatchParallelGetNext(self):
+ iterator = (dataset_ops.Dataset.range(50000)
+ .apply(batching.map_and_batch(lambda x: x, batch_size=100))
+ .make_one_shot_iterator())
+ elements = []
+ for _ in range(100):
+ elements.append(iterator.get_next())
+ with self.cached_session() as sess:
+ for i in range(5):
+ got = sess.run(elements)
+ got.sort(key=lambda x: x[0])
+ expected = []
+ for j in range(100):
+ expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ self.assertAllEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elements)
+
+ def testMapAndBatchParallelGetNextDropRemainder(self):
+ iterator = (
+ dataset_ops.Dataset.range(49999).apply(
+ batching.map_and_batch(
+ lambda x: x, batch_size=100, drop_remainder=True))
+ .make_one_shot_iterator())
+ elements = []
+ for _ in range(100):
+ elements.append(iterator.get_next())
+ with self.cached_session() as sess:
+ for i in range(4):
+ got = sess.run(elements)
+ got.sort(key=lambda x: x[0])
+ expected = []
+ for j in range(100):
+ expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ self.assertAllEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elements)
+
+ def testMapAndBatchSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for i in range(2):
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
+ dense_shape=[5, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testMapAndBatchFails(self):
+ """Test a dataset that maps a TF function across its input elements."""
+ dataset = dataset_ops.Dataset.from_tensors(
+ array_ops.check_numerics(
+ constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
+ sess.run(init_op, feed_dict={batch_size: 14})
+
+ def testMapAndBatchShapeMismatch(self):
+ """Test a dataset that maps a TF function across its input elements."""
+
+ def generator():
+ yield [1]
+ yield [2]
+ yield [3]
+ yield [[4, 5, 6]]
+
+ dataset = dataset_ops.Dataset.from_generator(
+ generator, output_types=dtypes.int32)
+ batch_size = 4
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "number of elements does not match"):
+ sess.run(get_next)
+
+ def testMapAndBatchImplicitDispose(self):
+ # Tests whether a map and batch dataset will be cleaned up correctly when
+ # the pipeline does not run it until exhaustion.
+ # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
+ # MapAndBatchDataset(f=square_3, batch_size=100).
+ components = (np.arange(1000),
+ np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
+ np.array(37.0) * np.arange(1000))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
+ 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
+ dataset = dataset.prefetch(5)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(3):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
+ def testMapAndBatchOutOfRangeError(self, threshold):
+
+ def raising_py_fn(i):
+ if i >= threshold:
+ raise StopIteration()
+ else:
+ return i
+
+ iterator = (
+ dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(
+ lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
+ batch_size=10)).make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(threshold // 10):
+ self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
+ if threshold % 10 != 0:
+ self.assertAllEqual(
+ [threshold // 10 * 10 + j for j in range(threshold % 10)],
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
+ )
+ def testMapAndBatchTypes(self, element, dtype):
+ def gen():
+ yield element
+
+ dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
+ batching.map_and_batch(lambda x: x, batch_size=10))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(10):
+ self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
+
+
+class UnbatchDatasetBenchmark(test.Benchmark):
+
+ def benchmarkNativeUnbatch(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.apply(batching.unbatch())
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (native) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_native_batch_size_%d" %
+ batch_size)
+
+ # Include a benchmark of the previous `unbatch()` implementation that uses
+ # a composition of more primitive ops. Eventually we'd hope to generate code
+ # that is as good in both cases.
+ def benchmarkOldUnbatchImplementation(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (unfused) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
+ batch_size)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
new file mode 100644
index 0000000000..153a03989b
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
@@ -0,0 +1,824 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class GroupByReducerTest(test_base.DatasetTestBase):
+
+ def checkResults(self, dataset, shapes, values):
+ self.assertEqual(shapes, dataset.output_shapes)
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ for expected in values:
+ got = sess.run(get_next)
+ self.assertEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSum(self):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
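+    # Grouping range(2 * i) by parity gives i even values summing to
+    # i * (i - 1) and i odd values summing to i * i, which checkResults
+    # verifies below.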
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(lambda x: x % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testAverage(self):
+
+ def reduce_fn(x, y):
+ return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
+ x[1] + 1), x[1] + 1
+
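+    # The state is a (running_mean, count) pair: reduce_fn folds each value
+    # into the running mean and finalize_func drops the count.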
+ reducer = grouping.Reducer(
+ init_func=lambda _: (0.0, 0.0),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x, _: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(
+ lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
+
+ def testConcat(self):
+ components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
+ reducer = grouping.Reducer(
+ init_func=lambda x: "",
+ reduce_func=lambda x, y: x + y[0],
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensor_slices(components),
+ dataset_ops.Dataset.range(2 * i))).apply(
+ grouping.group_by_reducer(lambda x, y: y % 2, reducer))
+ self.checkResults(
+ dataset,
+ shapes=tensor_shape.scalar(),
+ values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
+
+ def testSparseSum(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1], dtype=np.int64)),
+ dense_shape=np.array([1, 1]))
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: _sparse(np.int64(0)),
+ reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
+ finalize_func=lambda x: x.values[0])
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
+ grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testChangingStateShape(self):
+
+ def reduce_fn(x, _):
+ # Statically known rank, but dynamic length.
+ larger_dim = array_ops.concat([x[0], x[0]], 0)
+ # Statically unknown rank.
+ larger_rank = array_ops.expand_dims(x[1], 0)
+ return larger_dim, larger_rank
+
+ reducer = grouping.Reducer(
+ init_func=lambda x: ([0], 1),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x, y: (x, y))
+
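+    # After i reduce steps the first state component doubles to length 2**i
+    # and the second gains one dimension per step (rank i), matching the
+    # assertions below.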
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
+ grouping.group_by_reducer(lambda x: x, reducer))
+ self.assertEqual([None], dataset.output_shapes[0].as_list())
+ self.assertIs(None, dataset.output_shapes[1].ndims)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual([0] * (2**i), x)
+ self.assertAllEqual(np.array(1, ndmin=i), y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testTypeMismatch(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
+ reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The element types for the new state must match the initial state."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64(0), reducer))
+
+ # TODO(b/78665031): Remove once non-scalar keys are supported.
+ def testInvalidKeyShape(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
+
+ # TODO(b/78665031): Remove once non-int64 keys are supported.
+ def testInvalidKeyType(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: "wrong", reducer))
+
+ def testTuple(self):
+ def init_fn(_):
+ return np.array([], dtype=np.int64), np.int64(0)
+
+ def reduce_fn(state, value):
+ s1, s2 = state
+ v1, v2 = value
+ return array_ops.concat([s1, [v1]], 0), s2 + v2
+
+ def finalize_fn(s1, s2):
+ return s1, s2
+
+ reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
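+    # Every element maps to key 0, so a single group accumulates all ten
+    # values: the concatenation of range(10) and their sum, 45.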
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
+ grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual(x, np.asarray([x for x in range(10)]))
+ self.assertEqual(y, 45)
+
+
+class GroupByWindowTest(test_base.DatasetTestBase):
+
+ def testSimple(self):
+ components = np.random.randint(100, size=(200,)).astype(np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
+ .apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ counts = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+          result = sess.run(get_next)
+          self.assertTrue(
+              all(x % 2 == 0 for x in result) or
+              all(x % 2 == 1 for x in result))
+ counts.append(result.shape[0])
+
+ self.assertEqual(len(components), sum(counts))
+ num_full_batches = len([c for c in counts if c == 4])
+ self.assertGreaterEqual(num_full_batches, 24)
+ self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
+
+ def testImmediateOutput(self):
+ components = np.array(
+ [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
+ grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ # The input is infinite, so this test demonstrates that:
+ # 1. We produce output without having to consume the entire input,
+ # 2. Different buckets can produce output at different rates, and
+ # 3. For deterministic input, the output is deterministic.
+ for _ in range(3):
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
+ self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+
+ def testSmallGroups(self):
+ components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
+ # The small outputs at the end are deterministically produced in key
+ # order.
+ self.assertAllEqual([0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1], sess.run(get_next))
+
+ def testEmpty(self):
+ iterator = (
+ dataset_ops.Dataset.range(4).apply(
+ grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Window size must be greater than zero, but got 0."):
+ print(sess.run(get_next))
+
+ def testReduceFuncError(self):
+ components = np.random.randint(100, size=(200,)).astype(np.int64)
+
+ def reduce_func(_, xs):
+ # Introduce an incorrect padded shape that cannot (currently) be
+ # detected at graph construction time.
+ return xs.padded_batch(
+ 4,
+ padded_shapes=(tensor_shape.TensorShape([]),
+ constant_op.constant([5], dtype=dtypes.int64) * -1))
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
+ grouping.group_by_window(lambda x, _: x % 2, reduce_func,
+ 32)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def testConsumeWindowDatasetMoreThanOnce(self):
+ components = np.random.randint(50, size=(200,)).astype(np.int64)
+
+ def reduce_func(key, window):
+ # Apply two different kinds of padding to the input: tight
+ # padding, and quantized (to a multiple of 10) padding.
+ return dataset_ops.Dataset.zip((
+ window.padded_batch(
+ 4, padded_shapes=tensor_shape.TensorShape([None])),
+ window.padded_batch(
+ 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
+ ))
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
+ .apply(grouping.group_by_window(
+ lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
+ reduce_func, 4))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ counts = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ tight_result, multiple_of_10_result = sess.run(get_next)
+ self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
+ self.assertAllEqual(tight_result,
+ multiple_of_10_result[:, :tight_result.shape[1]])
+ counts.append(tight_result.shape[0])
+ self.assertEqual(len(components), sum(counts))
+
+
+# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
+# Currently they use a constant batch size; eventually they should be made to
+# use a different batch size per key.
+class BucketTest(test_base.DatasetTestBase):
+
+ def _dynamicPad(self, bucket, window, window_size):
+ # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
+ # generic form of padded_batch that pads every component
+ # dynamically and does not rely on static shape information about
+ # the arguments.
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket),
+ window.padded_batch(
+ 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
+ [None]), tensor_shape.TensorShape([3])))))
+
+ def testSingleBucket(self):
+
+ def _map_fn(v):
+ return (v, array_ops.fill([v], v),
+ array_ops.fill([3], string_ops.as_string(v)))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda x, y, z: 0,
+ lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ which_bucket, bucketed_values = sess.run(get_next)
+
+ self.assertEqual(0, which_bucket)
+
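+      # Element v contributes the row [v] * v, zero-padded to length 31 (the
+      # longest element in the window).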
+ expected_scalar_int = np.arange(32, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
+ for i in range(32):
+ expected_unk_int64[i, :i] = i
+ expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values[2])
+
+ def testEvenOddBuckets(self):
+
+ def _map_fn(v):
+ return (v, array_ops.fill([v], v),
+ array_ops.fill([3], string_ops.as_string(v)))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
+ lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ # Get two minibatches (one containing even values, one containing odds)
+ which_bucket_even, bucketed_values_even = sess.run(get_next)
+ which_bucket_odd, bucketed_values_odd = sess.run(get_next)
+
+ # Count number of bucket_tensors.
+ self.assertEqual(3, len(bucketed_values_even))
+ self.assertEqual(3, len(bucketed_values_odd))
+
+ # Ensure bucket 0 was used for all minibatch entries.
+ self.assertAllEqual(0, which_bucket_even)
+ self.assertAllEqual(1, which_bucket_odd)
+
+      # Test the first bucket output: the even values starting at 0
+ expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
+ for i in range(0, 32):
+ expected_unk_int64[i, :2 * i] = 2 * i
+ expected_vec3_str = np.vstack(
+ 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
+
+      # Test the second bucket output: the odd values starting at 1
+ expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
+ for i in range(0, 32):
+ expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
+ expected_vec3_str = np.vstack(
+ 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
+
+ def testEvenOddBucketsFilterOutAllOdd(self):
+
+ def _map_fn(v):
+ return {
+ "x": v,
+ "y": array_ops.fill([v], v),
+ "z": array_ops.fill([3], string_ops.as_string(v))
+ }
+
+ def _dynamic_pad_fn(bucket, window, _):
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket),
+ window.padded_batch(
+ 32, {
+ "x": tensor_shape.TensorShape([]),
+ "y": tensor_shape.TensorShape([None]),
+ "z": tensor_shape.TensorShape([3])
+ })))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
+ .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
+ lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ # Get two minibatches ([0, 2, ...] and [64, 66, ...])
+ which_bucket0, bucketed_values_even0 = sess.run(get_next)
+ which_bucket1, bucketed_values_even1 = sess.run(get_next)
+
+ # Ensure that bucket 1 was completely filtered out
+ self.assertAllEqual(0, which_bucket0)
+ self.assertAllEqual(0, which_bucket1)
+ self.assertAllEqual(
+ np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
+ self.assertAllEqual(
+ np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
+
+ def testDynamicWindowSize(self):
+ components = np.arange(100).astype(np.int64)
+
+    # Key fn: even/odd
+    # Reduce fn: batch the whole window (batch_size=20 exceeds both windows)
+    # Window size fn: even=5, odd=10
+    # With 100 elements this yields 50/5 = 10 even batches and 50/10 = 5 odd
+    # batches, i.e. the 15 batches asserted below.
+
+ def window_size_func(key):
+ window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
+ return window_sizes[key]
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
+ None, window_size_func))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ batches = 0
+ while True:
+ result = sess.run(get_next)
+ is_even = all(x % 2 == 0 for x in result)
+ is_odd = all(x % 2 == 1 for x in result)
+ self.assertTrue(is_even or is_odd)
+ expected_batch_size = 5 if is_even else 10
+ self.assertEqual(expected_batch_size, result.shape[0])
+ batches += 1
+
+ self.assertEqual(batches, 15)
+
+
+def _element_length_fn(x, y=None):
+ del y
+ return array_ops.shape(x)[0]
+
+
+def _to_sparse_tensor(record):
+ return sparse_tensor.SparseTensor(**record)
+
+
+def _format_record(array, sparse):
+ if sparse:
+ return {
+ "values": array,
+ "indices": [[i] for i in range(len(array))],
+ "dense_shape": (len(array),)
+ }
+ return array
+
+
+def _get_record_type(sparse):
+ if sparse:
+ return {
+ "values": dtypes.int64,
+ "indices": dtypes.int64,
+ "dense_shape": dtypes.int64
+ }
+ return dtypes.int32
+
+
+def _get_record_shape(sparse):
+ if sparse:
+ return {
+ "values": tensor_shape.TensorShape([None,]),
+ "indices": tensor_shape.TensorShape([None, 1]),
+ "dense_shape": tensor_shape.TensorShape([1,])
+ }
+ return tensor_shape.TensorShape([None])
+
+
+class BucketBySequenceLength(test_base.DatasetTestBase):
+
+ def testBucket(self):
+
+ boundaries = [10, 20, 30]
+ batch_sizes = [10, 8, 4, 2]
+ lengths = [8, 13, 25, 35]
+
+ def build_dataset(sparse):
+ def _generator():
+ # Produce 1 batch for each bucket
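+        # The first record in each bucket is one element shorter than the
+        # rest, so every batch sums to batch_size * length - 1, which the
+        # sum check below uses to verify bucket membership.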
+ elements = []
+ for batch_size, length in zip(batch_sizes, lengths):
+ record_len = length - 1
+ for _ in range(batch_size):
+ elements.append([1] * record_len)
+ record_len = length
+ random.shuffle(elements)
+ for el in elements:
+ yield (_format_record(el, sparse),)
+ dataset = dataset_ops.Dataset.from_generator(
+ _generator,
+ (_get_record_type(sparse),),
+ (_get_record_shape(sparse),))
+ if sparse:
+ dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
+ return dataset
+
+ def _test_bucket_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(
+ grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ batch_sizes,
+ no_padding=no_padding))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(4):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ shape = batch.dense_shape if no_padding else batch.shape
+ batch_size = shape[0]
+ length = shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ sum_check = batch.values.sum() if no_padding else batch.sum()
+ self.assertEqual(sum_check, batch_size * length - 1)
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+ for no_padding in (True, False):
+ _test_bucket_by_padding(no_padding)
+
+ def testPadToBoundary(self):
+
+ boundaries = [10, 20, 30]
+ batch_sizes = [10, 8, 4, 2]
+ lengths = [8, 13, 25]
+
+ def element_gen():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes[:-1], lengths):
+ for _ in range(batch_size):
+ elements.append([1] * length)
+ random.shuffle(elements)
+ for el in elements:
+ yield (el,)
+ for _ in range(batch_sizes[-1]):
+ el = [1] * (boundaries[-1] + 5)
+ yield (el,)
+
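+    # The final batch_sizes[-1] elements are longer than the largest boundary,
+    # so with pad_to_bucket_boundary=True they cannot be padded to a bucket
+    # boundary; the test expects the "bucket_boundaries" op error below.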
+ element_len = lambda el: array_ops.shape(el)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(3):
+ batches.append(sess.run(batch))
+ with self.assertRaisesOpError("bucket_boundaries"):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ batch_size = batch.shape[0]
+ length = batch.shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ batch_sizes = batch_sizes[:-1]
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
+ sorted(lengths_val))
+
+ def testPadToBoundaryNoExtraneousPadding(self):
+
+ boundaries = [3, 7, 11]
+ batch_sizes = [2, 2, 2, 2]
+ lengths = range(1, 11)
+
+ def element_gen():
+ for length in lengths:
+ yield ([1] * length,)
+
+ element_len = lambda element: array_ops.shape(element)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(5):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+
+ self.assertAllEqual(batches[0], [[1, 0],
+ [1, 1]])
+ self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1]])
+ self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
+
+ def testTupleElements(self):
+
+ def build_dataset(sparse):
+ def _generator():
+ text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+ label = [1, 2, 1, 2]
+ for x, y in zip(text, label):
+ yield (_format_record(x, sparse), y)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=_generator,
+ output_types=(_get_record_type(sparse), dtypes.int32),
+ output_shapes=(_get_record_shape(sparse),
+ tensor_shape.TensorShape([])))
+ if sparse:
+ dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
+ return dataset
+
+ def _test_tuple_elements_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=_element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8],
+ no_padding=no_padding))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
+ for no_padding in (True, False):
+ _test_tuple_elements_by_padding(no_padding)
+
+ def testBucketSparse(self):
+ """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+    The test runs on the following dataset:
+ [
+ [0],
+ [0, 1],
+ [0, 1, 2]
+ ...
+ [0, ..., max_len - 1]
+ ]
+ Sequences are bucketed by length and batched with
+ `batch_size` < `bucket_size`.
+ """
+
+ min_len = 0
+ max_len = 100
+ batch_size = 7
+ bucket_size = 10
+
+ def _build_dataset():
+ input_data = [range(i+1) for i in range(min_len, max_len)]
+ def generator_fn():
+ for record in input_data:
+ yield _format_record(record, sparse=True)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=generator_fn,
+ output_types=_get_record_type(sparse=True))
+ dataset = dataset.map(_to_sparse_tensor)
+ return dataset
+
+ def _compute_expected_batches():
+ """Computes expected batch outputs and stores in a set."""
+ all_expected_sparse_tensors = set()
+ for bucket_start_len in range(min_len, max_len, bucket_size):
+ for batch_offset in range(0, bucket_size, batch_size):
+ batch_start_len = bucket_start_len + batch_offset
+ batch_end_len = min(batch_start_len + batch_size,
+ bucket_start_len + bucket_size)
+ expected_indices = []
+ expected_values = []
+ for length in range(batch_start_len, batch_end_len):
+ for val in range(length + 1):
+ expected_indices.append((length - batch_start_len, val))
+ expected_values.append(val)
+ expected_sprs_tensor = (tuple(expected_indices),
+ tuple(expected_values))
+ all_expected_sparse_tensors.add(expected_sprs_tensor)
+ return all_expected_sparse_tensors
+
+ def _compute_batches(dataset):
+ """Computes actual batch outputs of dataset and stores in a set."""
+ batch = dataset.make_one_shot_iterator().get_next()
+ all_sparse_tensors = set()
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ output = sess.run(batch)
+ sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+ tuple(output.values))
+ all_sparse_tensors.add(sprs_tensor)
+ return all_sparse_tensors
+
+ dataset = _build_dataset()
+ boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
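+    # boundaries == [11, 21, ..., 91], so each bucket covers bucket_size
+    # consecutive lengths; with batch_size (7) < bucket_size (10), every
+    # bucket yields one batch of 7 and one partial batch of 3.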
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ [batch_size] * (len(boundaries) + 1),
+ no_padding=True))
+ batches = _compute_batches(dataset)
+ expected_batches = _compute_expected_batches()
+ self.assertEqual(batches, expected_batches)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py
new file mode 100644
index 0000000000..4ee1779710
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py
@@ -0,0 +1,632 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for CsvDatasetOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import string
+import tempfile
+import time
+import zlib
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class CsvDatasetOpTest(test_base.DatasetTestBase):
+
+ def _setup_files(self, inputs, linebreak='\n', compression_type=None):
+ filenames = []
+ for i, ip in enumerate(inputs):
+ fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
+ contents = linebreak.join(ip).encode('utf-8')
+ if compression_type is None:
+ with open(fn, 'wb') as f:
+ f.write(contents)
+ elif compression_type == 'GZIP':
+ with gzip.GzipFile(fn, 'wb') as f:
+ f.write(contents)
+ elif compression_type == 'ZLIB':
+ contents = zlib.compress(contents)
+ with open(fn, 'wb') as f:
+ f.write(contents)
+ else:
+ raise ValueError('Unsupported compression_type', compression_type)
+ filenames.append(fn)
+ return filenames
+
+ def _make_test_datasets(self, inputs, **kwargs):
+ # Test by comparing its output to what we could get with map->decode_csv
+ filenames = self._setup_files(inputs)
+ dataset_expected = core_readers.TextLineDataset(filenames)
+ dataset_expected = dataset_expected.map(
+ lambda l: parsing_ops.decode_csv(l, **kwargs))
+ dataset_actual = readers.CsvDataset(filenames, **kwargs)
+ return (dataset_actual, dataset_expected)
+
+ def _test_by_comparison(self, inputs, **kwargs):
+ """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
+ dataset_actual, dataset_expected = self._make_test_datasets(
+ inputs, **kwargs)
+ self.assertDatasetsEqual(dataset_actual, dataset_expected)
+
+ def _verify_output_or_err(self,
+ dataset,
+ expected_output=None,
+ expected_err_re=None):
+ if expected_err_re is None:
+ # Verify that output is expected, without errors
+ nxt = self.getNext(dataset)
+ expected_output = [[
+ v.encode('utf-8') if isinstance(v, str) else v for v in op
+ ] for op in expected_output]
+ for value in expected_output:
+ op = self.evaluate(nxt())
+ self.assertAllEqual(op, value)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(nxt())
+ else:
+ # Verify that OpError is produced as expected
+ with self.assertRaisesOpError(expected_err_re):
+ nxt = self.getNext(dataset)
+ while True:
+ try:
+ self.evaluate(nxt())
+ except errors.OutOfRangeError:
+ break
+
+ def _test_dataset(
+ self,
+ inputs,
+ expected_output=None,
+ expected_err_re=None,
+ linebreak='\n',
+ compression_type=None, # Used for both setup and parsing
+ **kwargs):
+ """Checks that elements produced by CsvDataset match expected output."""
+ # Convert str type because py3 tf strings are bytestrings
+ filenames = self._setup_files(inputs, linebreak, compression_type)
+ kwargs['compression_type'] = compression_type
+ dataset = readers.CsvDataset(filenames, **kwargs)
+ self._verify_output_or_err(dataset, expected_output, expected_err_re)
+
+ def testCsvDataset_requiredFields(self):
+ record_defaults = [[]] * 4
+ inputs = [['1,2,3,4']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_int(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,2,3,4', '5,6,7,8']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_float(self):
+ record_defaults = [[0.0]] * 4
+ inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_string(self):
+ record_defaults = [['']] * 4
+ inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withEmptyFields(self):
+ record_defaults = [[0]] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errWithUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Unquoted fields cannot have quotes inside',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errWithUnescapedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['"a"b","c","d"']]
+ self._test_dataset(
+ inputs,
+ expected_err_re=
+ 'Quote inside a string has to be escaped by another quote',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
+ filenames = self._setup_files(inputs)
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
+
+ def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
+ filenames = self._setup_files(inputs)
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
+
+ def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, use_quote_delim=False)
+
+ def testCsvDataset_mixedTypes(self):
+ record_defaults = [
+ constant_op.constant([], dtype=dtypes.int32),
+ constant_op.constant([], dtype=dtypes.float32),
+ constant_op.constant([], dtype=dtypes.string),
+ constant_op.constant([], dtype=dtypes.float64)
+ ]
+ inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withUseQuoteDelimFalse(self):
+ record_defaults = [['']] * 4
+ inputs = [['1,2,"3,4"', '"5,6",7,8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, use_quote_delim=False)
+
+ def testCsvDataset_withFieldDelim(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1:2:3:4', '5:6:7:8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, field_delim=':')
+
+ def testCsvDataset_withNaValue(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,NA,3,4', 'NA,6,7,8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, na_value='NA')
+
+ def testCsvDataset_withSelectCols(self):
+ record_defaults = [['']] * 2
+ inputs = [['1,2,3,4', '"5","6","7","8"']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, select_cols=[1, 2])
+
+ def testCsvDataset_withSelectColsTooHigh(self):
+ record_defaults = [[0]] * 2
+ inputs = [['1,2,3,4', '5,6,7,8']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have 1 in record',
+ record_defaults=record_defaults,
+ select_cols=[3, 4])
+
+ def testCsvDataset_withOneCol(self):
+ record_defaults = [['NA']]
+ inputs = [['0', '', '2']]
+ self._test_dataset(
+ inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
+
+ def testCsvDataset_withMultipleFiles(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withLeadingAndTrailingSpaces(self):
+ record_defaults = [[0.0]] * 4
+ inputs = [['0, 1, 2, 3']]
+ expected = [[0.0, 1.0, 2.0, 3.0]]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithMissingDefault(self):
+ record_defaults = [[]] * 2
+ inputs = [['0,']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Field 1 is required but missing in record!',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithFewerDefaultsThanFields(self):
+ record_defaults = [[0.0]] * 2
+ inputs = [['0,1,2,3']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have more in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithMoreDefaultsThanFields(self):
+ record_defaults = [[0.0]] * 5
+ inputs = [['0,1,2,3']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 5 fields but have 4 in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withHeader(self):
+ record_defaults = [[0]] * 2
+ inputs = [['col1,col2', '1,2']]
+ expected = [[1, 2]]
+ self._test_dataset(
+ inputs,
+ expected,
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_withHeaderAndNoRecords(self):
+ record_defaults = [[0]] * 2
+ inputs = [['col1,col2']]
+ expected = []
+ self._test_dataset(
+ inputs,
+ expected,
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_errorWithHeaderEmptyFile(self):
+ record_defaults = [[0]] * 2
+ inputs = [[]]
+ expected_err_re = "Can't read header of file"
+ self._test_dataset(
+ inputs,
+ expected_err_re=expected_err_re,
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_withEmptyFile(self):
+ record_defaults = [['']] * 2
+ inputs = [['']] # Empty file
+ self._test_dataset(
+ inputs, expected_output=[], record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithEmptyRecord(self):
+ record_defaults = [['']] * 2
+ inputs = [['', '1,2']] # First record is empty
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have 1 in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withChainedOps(self):
+ # Testing that one dataset can create multiple iterators fine.
+ # `repeat` creates multiple iterators from the same C++ Dataset.
+ record_defaults = [[0]] * 4
+ inputs = [['1,,3,4', '5,6,,8']]
+ ds_actual, ds_expected = self._make_test_datasets(
+ inputs, record_defaults=record_defaults)
+ self.assertDatasetsEqual(
+ ds_actual.repeat(5).prefetch(1),
+ ds_expected.repeat(5).prefetch(1))
+
+ def testCsvDataset_withTypeDefaults(self):
+ # Testing using dtypes as record_defaults for required fields
+ record_defaults = [dtypes.float32, [0.0]]
+ inputs = [['1.0,2.0', '3.0,4.0']]
+ self._test_dataset(
+ inputs,
+ [[1.0, 2.0], [3.0, 4.0]],
+ record_defaults=record_defaults,
+ )
+
+ def testMakeCsvDataset_fieldOrder(self):
+ data = [[
+ '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
+ '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
+ ]]
+ file_path = self._setup_files(data)
+
+ ds = readers.make_csv_dataset(
+ file_path, batch_size=1, shuffle=False, num_epochs=1)
+ nxt = self.getNext(ds)
+
+ result = list(self.evaluate(nxt()).values())
+
+ self.assertEqual(result, sorted(result))
+
+## The following tests exercise parsing logic for quoted fields
+
+ def testCsvDataset_withQuoted(self):
+ record_defaults = [['']] * 4
+ inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withOneColAndQuotes(self):
+ record_defaults = [['']]
+ inputs = [['"0"', '"1"', '"2"']]
+ self._test_dataset(
+ inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
+
+ def testCsvDataset_withNewLine(self):
+ # In this case, we expect it to behave differently from
+ # TextLineDataset->map(decode_csv) since that flow has bugs
+ record_defaults = [['']] * 4
+ inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_withNewLineInUnselectedCol(self):
+ record_defaults = [['']]
+ inputs = [['1,"2\n3",4', '5,6,7']]
+ self._test_dataset(
+ inputs,
+ expected_output=[['1'], ['5']],
+ record_defaults=record_defaults,
+ select_cols=[0])
+
+ def testCsvDataset_withMultipleNewLines(self):
+ # In this case, we expect it to behave differently from
+ # TextLineDataset->map(decode_csv) since that flow has bugs
+ record_defaults = [['']] * 4
+ inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithTerminateMidRecord(self):
+ record_defaults = [['']] * 4
+ inputs = [['a,b,c,"a']]
+ self._test_dataset(
+ inputs,
+ expected_err_re=
+ 'Reached end of file without closing quoted field in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withEscapedQuotes(self):
+ record_defaults = [['']] * 4
+ inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+
+## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
+## and different types of line breaks
+
+ def testCsvDataset_withInvalidBufferSize(self):
+ record_defaults = [['']] * 4
+ inputs = [['a,b,c,d']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='buffer_size should be positive',
+ record_defaults=record_defaults,
+ buffer_size=0)
+
+ def _test_dataset_on_buffer_sizes(self,
+ inputs,
+ expected,
+ linebreak,
+ record_defaults,
+ compression_type=None,
+ num_sizes_to_test=20):
+ # Testing reading with a range of buffer sizes that should all work.
+ for i in list(range(1, 1 + num_sizes_to_test)) + [None]:
+ self._test_dataset(
+ inputs,
+ expected,
+ linebreak=linebreak,
+ compression_type=compression_type,
+ record_defaults=record_defaults,
+ buffer_size=i)
+
+ def testCsvDataset_withLF(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withCR(self):
+ # Test that when the line separator is '\r', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r', record_defaults=record_defaults)
+
+ def testCsvDataset_withCRLF(self):
+ # Test that when the line separator is '\r\n', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withBufferSizeAndQuoted(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withCRAndQuoted(self):
+ # Test that when the line separator is '\r', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r', record_defaults=record_defaults)
+
+ def testCsvDataset_withCRLFAndQuoted(self):
+ # Test that when the line separator is '\r\n', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withGzipCompressionType(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs,
+ expected,
+ linebreak='\r\n',
+ compression_type='GZIP',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withZlibCompressionType(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs,
+ expected,
+ linebreak='\r\n',
+ compression_type='ZLIB',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withScalarDefaults(self):
+ record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+ def testCsvDataset_with2DDefaults(self):
+ record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+
+ if context.executing_eagerly():
+ err_spec = errors.InvalidArgumentError, (
+ 'Each record default should be at '
+ 'most rank 1.')
+ else:
+ err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
+
+ with self.assertRaisesWithPredicateMatch(*err_spec):
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+
+class CsvDatasetBenchmark(test.Benchmark):
+ """Benchmarks for the various ways of creating a dataset from CSV files.
+ """
+ FLOAT_VAL = '1.23456E12'
+ STR_VAL = string.ascii_letters * 10
+
+ def _setUp(self, str_val):
+    # Since this isn't a test.TestCase, we have to create a test dir manually.
+ gfile.MakeDirs(googletest.GetTempDir())
+ self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
+
+ self._num_cols = [4, 64, 256]
+ self._num_per_iter = 5000
+ self._filenames = []
+ for n in self._num_cols:
+ fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
+ with open(fn, 'wb') as f:
+ # Just write 100 rows and use `repeat`... Assumes the cost
+ # of creating an iterator is not significant
+ row = ','.join([str_val for _ in range(n)])
+        f.write('\n'.join([row for _ in range(100)]).encode('utf-8'))
+ self._filenames.append(fn)
+
+ def _tearDown(self):
+ gfile.DeleteRecursively(self._temp_dir)
+
+ def _runBenchmark(self, dataset, num_cols, prefix):
+ dataset = dataset.skip(self._num_per_iter - 1)
+ deltas = []
+ for _ in range(10):
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with session.Session() as sess:
+ start = time.time()
+ # NOTE: This depends on the underlying implementation of skip, to have
+ # the net effect of calling `GetNext` num_per_iter times on the
+ # input dataset. We do it this way (instead of a python for loop, or
+ # batching N inputs in one iter) so that the overhead from session.run
+ # or batch doesn't dominate. If we eventually optimize skip, this has
+ # to change.
+ sess.run(next_element)
+ end = time.time()
+ deltas.append(end - start)
+ # Median wall time per CSV record read and decoded
+ median_wall_time = np.median(deltas) / self._num_per_iter
+ print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols,
+ median_wall_time))
+ self.report_benchmark(
+ iters=self._num_per_iter,
+ wall_time=median_wall_time,
+ name='%s_with_cols_%d' % (prefix, num_cols))
+
+ def benchmarkMapWithFloats(self):
+ self._setUp(self.FLOAT_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [[0.0]] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv')
+ self._tearDown()
+
+ def benchmarkMapWithStrings(self):
+ self._setUp(self.STR_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [['']] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
+ self._tearDown()
+
+ def benchmarkCsvDatasetWithFloats(self):
+ self._setUp(self.FLOAT_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [[0.0]] * num_cols}
+      dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()
+ self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset')
+ self._tearDown()
+
+ def benchmarkCsvDatasetWithStrings(self):
+ self._setUp(self.STR_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [['']] * num_cols}
+      dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()
+ self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
+ self._tearDown()
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py
new file mode 100644
index 0000000000..3fc7157bc5
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class DatasetConstructorTest(test_base.DatasetTestBase):
+
+ def testRestructureDataset(self):
+ components = (array_ops.placeholder(dtypes.int32),
+ (array_ops.placeholder(dtypes.int32, shape=[None]),
+ array_ops.placeholder(dtypes.int32, shape=[20, 30])))
+ dataset = dataset_ops.Dataset.from_tensors(components)
+
+ i32 = dtypes.int32
+
+ test_cases = [((i32, i32, i32), None),
+ (((i32, i32), i32), None),
+ ((i32, i32, i32), (None, None, None)),
+ ((i32, i32, i32), ([17], [17], [20, 30]))]
+
+ for new_types, new_shape_lists in test_cases:
+ # pylint: disable=protected-access
+ new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+ # pylint: enable=protected-access
+ self.assertEqual(new_types, new.output_types)
+ if new_shape_lists is not None:
+ for expected_shape_list, shape in zip(
+ nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)):
+ if expected_shape_list is None:
+ self.assertIs(None, shape.ndims)
+ else:
+ self.assertEqual(expected_shape_list, shape.as_list())
+
+ fail_cases = [((i32, dtypes.int64, i32), None),
+ ((i32, i32, i32, i32), None),
+ ((i32, i32, i32), ((None, None), None)),
+ ((i32, i32, i32), (None, None, None, None)),
+ ((i32, i32, i32), (None, [None], [21, 30]))]
+
+ for new_types, new_shape_lists in fail_cases:
+ with self.assertRaises(ValueError):
+ # pylint: disable=protected-access
+ new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+ # pylint: enable=protected-access
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
new file mode 100644
index 0000000000..7f435b8239
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
@@ -0,0 +1,692 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing serializable datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.util import nest
+
+
+def remove_variants(get_next_op):
+  """Remove variants from a nest structure, so sess.run will execute."""
+  # TODO(b/72408568): Remove this once session.run can get
+  # variant tensors.
+
+ def _remove_variant(x):
+ if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
+ return ()
+ else:
+ return x
+
+ return nest.map_structure(_remove_variant, get_next_op)
+
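+ # For example (illustrative names), `remove_variants((variant_tensor,
+ # int32_tensor))` returns `((), int32_tensor)`, which `sess.run` can then
+ # evaluate.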
+
+class DatasetSerializationTestBase(test.TestCase):
+ """Base class for testing serializable datasets."""
+
+ def tearDown(self):
+ self._delete_ckpt()
+
+ # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
+ # (deprecated) saveable `SparseTensorSliceDataset`, once the API
+ # `from_sparse_tensor_slices()` and related tests are deleted.
+ def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
+ """Runs the core tests.
+
+ Args:
+ ds_fn1: 0-argument function that returns a Dataset.
+ ds_fn2: 0-argument function that returns a Dataset different from
+ ds_fn1. If None, verify_restore_in_modified_graph test is not run.
+ num_outputs: Total number of outputs expected from this Dataset.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_unused_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_fully_used_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_exhausted_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_init_before_restore(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_multiple_breaks(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_reset_restored_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_restore_in_empty_graph(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ if ds_fn2:
+ self.verify_restore_in_modified_graph(
+ ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
+
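+ # A minimal usage sketch (the subclass name and range datasets below are
+ # hypothetical, chosen only for illustration): concrete tests subclass this
+ # base and pass 0-argument dataset constructors to `run_core_tests`, e.g.
+ #
+ #   class RangeDatasetSerializationTest(DatasetSerializationTestBase):
+ #
+ #     def testCore(self):
+ #       self.run_core_tests(lambda: dataset_ops.Dataset.range(20),
+ #                           lambda: dataset_ops.Dataset.range(30),
+ #                           num_outputs=20)
+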
+ def verify_unused_iterator(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that saving and restoring an unused iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [0],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_fully_used_iterator(self, ds_fn, num_outputs,
+ sparse_tensors=False):
+ """Verifies that saving and restoring a fully used iterator works.
+
+ Note that this only checks saving and restoring an iterator from which
+ `num_outputs` items have been produced but does not check for an
+ exhausted iterator, i.e., one from which an OutOfRange error has been
+ returned.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
+ """Verifies that saving and restoring an exhausted iterator works.
+
+ An exhausted iterator is one which has returned an OutOfRange error.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ actual = self.gen_outputs(
+ ds_fn, [],
+ 0,
+ ckpt_saved=True,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ self.assertEqual(len(actual), 0)
+
+ def verify_init_before_restore(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that restoring into an already initialized iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs),
+ num_outputs,
+ init_before_restore=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_multiple_breaks(self,
+ ds_fn,
+ num_outputs,
+ num_breaks=10,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to save/restore at multiple break points.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ num_breaks: The number of break points. These are spread uniformly over
+ [0, num_outputs], with both endpoints inclusive.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs, num_breaks),
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_reset_restored_iterator(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to re-initialize a restored iterator.
+
+ This is useful when restoring a training checkpoint during validation.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Collect ground truth containing all outputs.
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Skip some items and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Restore from checkpoint and then run init_op.
+ with ops.Graph().as_default() as g:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ self._initialize(init_op, sess)
+ for _ in range(num_outputs):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ self.match(expected, actual)
+
+ def verify_restore_in_modified_graph(self,
+ ds_fn1,
+ ds_fn2,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in a modified graph.
+
+ Builds an input pipeline using ds_fn1, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new graph using ds_fn2, restores
+ the checkpoint from ds_fn1 and verifies that the restore is successful.
+
+ Args:
+ ds_fn1: See `run_core_tests`.
+ ds_fn2: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn1
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn1, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn1 and save checkpoint.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build graph for ds_fn2 but load checkpoint for ds_fn1.
+ with ops.Graph().as_default() as g:
+ _, get_next_op, saver = self._build_graph(
+ ds_fn2, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_restore_in_empty_graph(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in an empty graph.
+
+ Builds an input pipeline using ds_fn, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new empty graph, restores
+ the checkpoint from ds_fn and verifies that the restore is successful.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build an empty graph but load checkpoint for ds_fn.
+ with ops.Graph().as_default() as g:
+ get_next_op, saver = self._build_empty_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_error_on_save(self,
+ ds_fn,
+ num_outputs,
+ error,
+ break_point=None,
+ sparse_tensors=False):
+ """Attempts to save a non-saveable iterator.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ error: The error expected to be raised when trying to save the iterator.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+
+ break_point = num_outputs // 2 if not break_point else break_point
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._initialize(init_op, sess)
+ for _ in range(break_point):
+ sess.run(get_next_op)
+ with self.assertRaises(error):
+ self._save(sess, saver)
+
+ def verify_run_with_breaks(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that ds_fn() produces the same outputs with and without breaks.
+
+ 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ *without* stopping at break points.
+ 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ while stopping at break points.
+
+ Then deep-matches the outputs from 1 and 2.
+
+ Args:
+ ds_fn: See `gen_outputs`.
+ break_points: See `gen_outputs`.
+ num_outputs: See `gen_outputs`.
+ init_before_restore: See `gen_outputs`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ self.match(expected, actual)
+
+ def gen_outputs(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ ckpt_saved=False,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True,
+ save_checkpoint_at_end=True):
+ """Generates elements from input dataset while stopping at break points.
+
+ Produces `num_outputs` outputs and saves the state of the iterator in the
+ Saver checkpoint.
+
+ Args:
+ ds_fn: 0-argument function that returns the dataset.
+ break_points: A list of integers. For each `break_point` in
+ `break_points`, we produce outputs until `break_point` items have been
+ produced and then checkpoint the state. The current graph and session
+ are then destroyed, and a new graph and session are used to produce
+ outputs until the next break point or until `num_outputs` elements have
+ been produced. Each `break_point` must be <= `num_outputs`.
+ num_outputs: The total number of outputs to produce from the iterator.
+ ckpt_saved: Whether a checkpoint already exists. If False, we build the
+ graph from ds_fn.
+ init_before_restore: Whether init should be called before saver.restore.
+ This is just so that we can verify that restoring an already initialized
+ iterator works.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+ verify_exhausted: Whether to verify that the iterator has been exhausted
+ after producing `num_outputs` elements.
+ save_checkpoint_at_end: Whether to save a checkpoint after producing all
+ outputs. If False, checkpoints are saved at each break point but not at
+ the end. Note that checkpoints overwrite each other, so there is always
+ only a single checkpoint available. Defaults to True.
+
+ Returns:
+ A list of `num_outputs` items.
+ """
+ outputs = []
+
+ def get_ops():
+ if ckpt_saved:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ else:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ return init_op, get_next_op, saver
+
+ for i in range(len(break_points) + 1):
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = get_ops()
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ if ckpt_saved:
+ if init_before_restore:
+ self._initialize(init_op, sess)
+ self._restore(saver, sess)
+ else:
+ self._initialize(init_op, sess)
+ start = break_points[i - 1] if i > 0 else 0
+ end = break_points[i] if i < len(break_points) else num_outputs
+ num_iters = end - start
+ for _ in range(num_iters):
+ outputs.append(sess.run(get_next_op))
+ if i == len(break_points) and verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ if save_checkpoint_at_end or i < len(break_points):
+ self._save(sess, saver)
+ ckpt_saved = True
+
+ return outputs
+
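+ # For example (hypothetical values), `break_points=[3, 7]` with
+ # `num_outputs=10` runs three graph/session phases producing 3, 4 and 3
+ # elements respectively, saving the iterator state to the checkpoint at the
+ # end of a phase and restoring it at the start of the next one.
+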
+ def match(self, expected, actual):
+ """Matches nested structures.
+
+ Recursively matches shape and values of `expected` and `actual`.
+ Handles scalars, numpy arrays and other python sequence containers
+ e.g. list, dict.
+
+ Args:
+ expected: Nested structure 1.
+ actual: Nested structure 2.
+
+ Raises:
+ AssertionError if matching fails.
+ """
+ if isinstance(expected, np.ndarray):
+ expected = expected.tolist()
+ if isinstance(actual, np.ndarray):
+ actual = actual.tolist()
+ self.assertEqual(type(expected), type(actual))
+
+ if nest.is_sequence(expected):
+ self.assertEqual(len(expected), len(actual))
+ if isinstance(expected, dict):
+ for key1, key2 in zip(sorted(expected), sorted(actual)):
+ self.assertEqual(key1, key2)
+ self.match(expected[key1], actual[key2])
+ else:
+ for item1, item2 in zip(expected, actual):
+ self.match(item1, item2)
+ else:
+ self.assertEqual(expected, actual)
+
+ def does_not_match(self, expected, actual):
+ with self.assertRaises(AssertionError):
+ self.match(expected, actual)
+
+ def gen_break_points(self, num_outputs, num_samples=10):
+ """Generates `num_samples` breaks points in [0, num_outputs]."""
+ return np.linspace(0, num_outputs, num_samples, dtype=int)
+
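+ # For example, `gen_break_points(20, 5)` returns `[0, 5, 10, 15, 20]`.
+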
+ def _build_graph(self, ds_fn, sparse_tensors=False):
+ iterator = ds_fn().make_initializable_iterator()
+
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ init_op = iterator.initializer
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
+ sparse_tensors)
+ saver = saver_lib.Saver(allow_empty=True)
+ return init_op, get_next, saver
+
+ def _build_empty_graph(self, ds_fn, sparse_tensors=False):
+ iterator = iterator_ops.Iterator.from_structure(
+ self._get_output_types(ds_fn),
+ output_shapes=self._get_output_shapes(ds_fn),
+ output_classes=self._get_output_classes(ds_fn))
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ saver = saver_lib.Saver(allow_empty=True)
+ return get_next, saver
+
+ def _add_iterator_ops_to_collection(self,
+ init_op,
+ get_next,
+ ds_fn,
+ sparse_tensors=False):
+ ops.add_to_collection("iterator_ops", init_op)
+ # `get_next` may be a tuple, e.g. in TensorSliceDataset. Since collections
+ # do not support tuples, we flatten the tensors and restore the shape in
+ # `_get_iterator_ops_from_collection`.
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ ops.add_to_collection("iterator_ops", get_next.indices)
+ ops.add_to_collection("iterator_ops", get_next.values)
+ ops.add_to_collection("iterator_ops", get_next.dense_shape)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
+
+ def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
+ all_ops = ops.get_collection("iterator_ops")
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ init_op, indices, values, dense_shape = all_ops
+ return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
+
+ def _get_output_types(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_types
+
+ def _get_output_shapes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_shapes
+
+ def _get_output_classes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_classes
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _latest_ckpt(self):
+ return checkpoint_management.latest_checkpoint(self.get_temp_dir())
+
+ def _save(self, sess, saver):
+ saver.save(sess, self._ckpt_path())
+
+ def _restore(self, saver, sess):
+ sess.run(lookup_ops.tables_initializer())
+ saver.restore(sess, self._latest_ckpt())
+
+ def _initialize(self, init_op, sess):
+ sess.run(variables.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ sess.run(init_op)
+
+ def _import_meta_graph(self):
+ meta_file_path = self._ckpt_path() + ".meta"
+ return saver_lib.import_meta_graph(meta_file_path)
+
+ def _delete_ckpt(self):
+ # Remove all checkpoint files. Use an explicit loop rather than `map()`,
+ # which is lazy in Python 3 and would not actually delete anything.
+ prefix = self._ckpt_path()
+ pattern = prefix + "*"
+ files = gfile.Glob(pattern)
+ for ckpt_file in files:
+ gfile.Remove(ckpt_file)
diff --git a/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
new file mode 100644
index 0000000000..796a692c56
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
@@ -0,0 +1,148 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import random_seed
+from tensorflow.python.platform import test
+
+
+class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
+
+ def testBasic(self):
+ selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
+ input_datasets = [
+ dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
+ ]
+ dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset,
+ input_datasets)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(100):
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def _normalize(self, vec):
+ return vec / vec.sum()
+
+ def _chi2(self, expected, actual):
+ actual = np.asarray(actual)
+ expected = np.asarray(expected)
+ diff = actual - expected
+ chi2 = np.sum(diff * diff / expected, axis=0)
+ return chi2
+
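+ # For example, with expected = [0.5, 0.5] and actual = [0.52, 0.48], the
+ # statistic is (0.02**2)/0.5 + (0.02**2)/0.5 = 0.0016, comfortably below the
+ # 1e-2 threshold used in `testSampleFromDatasets` below.
+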
+ def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
+ # Create a dataset that samples each integer in `[0, num_datasets)`
+ # with probability given by `weights[i]`.
+ dataset = interleave_ops.sample_from_datasets([
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(num_datasets)
+ ], weights)
+ dataset = dataset.take(num_samples)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ freqs = np.zeros([num_datasets])
+ for _ in range(num_samples):
+ freqs[sess.run(next_element)] += 1
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ return freqs
+
+ def testSampleFromDatasets(self):
+ random_seed.set_random_seed(1619)
+ num_samples = 5000
+ rand_probs = self._normalize(np.random.random_sample((15,)))
+
+ # Use chi-squared test to assert that the observed distribution matches the
+ # expected distribution. Based on the implementation in
+ # "third_party/tensorflow/python/kernel_tests/multinomial_op_test.py".
+ for probs in [[.85, .05, .1], rand_probs, [1.]]:
+ probs = np.asarray(probs)
+ classes = len(probs)
+ freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
+ self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
+
+ # Also check that `weights` as a dataset samples correctly.
+ probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
+ freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
+ self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
+
+ def testSelectFromDatasets(self):
+ words = [b"foo", b"bar", b"baz"]
+ datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
+ choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
+ choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
+ dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in choice_array:
+ self.assertEqual(words[i], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testErrors(self):
+ with self.assertRaisesRegexp(ValueError,
+ r"vector of length `len\(datasets\)`"):
+ interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.range(20)],
+ weights=[0.25, 0.25, 0.25, 0.25])
+
+ with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
+ interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.range(20)],
+ weights=[1, 1])
+
+ with self.assertRaisesRegexp(TypeError, "must have the same type"):
+ interleave_ops.sample_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(0.0)
+ ])
+
+ with self.assertRaisesRegexp(TypeError, "tf.int64"):
+ interleave_ops.choose_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(1)
+ ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))
+
+ with self.assertRaisesRegexp(TypeError, "scalar"):
+ interleave_ops.choose_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(1)
+ ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
new file mode 100644
index 0000000000..c6ee88c676
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks FilterDataset input pipeline op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FilterBenchmark(test.Benchmark):
+
+ # This benchmark compares the performance of a pipeline with multiple
+ # chained filters, with and without filter fusion.
+ def benchmarkFilters(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkFilters(chain_length, False)
+ self._benchmarkFilters(chain_length, True)
+
+ def _benchmarkFilters(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
+ if optimize_dataset:
+ dataset = dataset.apply(optimization.optimize(["filter_fusion"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(10):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "no-opt"
+ print("Filter dataset {} chain length: {} Median wall time: {}".format(
+ opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_filter_dataset_chain_latency_{}_{}".format(
+ opt_mark, chain_length))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
new file mode 100644
index 0000000000..8c07afbac5
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import get_single_element
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("Zero", 0, 1),
+ ("Five", 5, 1),
+ ("Ten", 10, 1),
+ ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
+ ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
+ "Dataset had more than one element."),
+ )
+ def testGetSingleElement(self, skip, take, error=None, error_msg=None):
+ skip_t = array_ops.placeholder(dtypes.int64, shape=[])
+ take_t = array_ops.placeholder(dtypes.int64, shape=[])
+
+ def make_sparse(x):
+ x_1d = array_ops.reshape(x, [1])
+ x_2d = array_ops.reshape(x, [1, 1])
+ return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
+
+ dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
+ lambda x: (x * x, make_sparse(x))).take(take_t)
+ element = get_single_element.get_single_element(dataset)
+
+ with self.cached_session() as sess:
+ if error is None:
+ dense_val, sparse_val = sess.run(
+ element, feed_dict={
+ skip_t: skip,
+ take_t: take
+ })
+ self.assertEqual(skip * skip, dense_val)
+ self.assertAllEqual([[skip]], sparse_val.indices)
+ self.assertAllEqual([skip], sparse_val.values)
+ self.assertAllEqual([skip], sparse_val.dense_shape)
+ else:
+ with self.assertRaisesRegexp(error, error_msg):
+ sess.run(element, feed_dict={skip_t: skip, take_t: take})
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
new file mode 100644
index 0000000000..c93a8353ce
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
@@ -0,0 +1,79 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental indexed dataset ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+from tensorflow.python.data.experimental.ops import indexed_dataset_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.platform import test
+
+
+class IndexedDatasetOpsTest(test_base.DatasetTestBase):
+
+ def testLowLevelIndexedDatasetOps(self):
+ identity = ged_ops.experimental_identity_indexed_dataset(
+ ops.convert_to_tensor(16, dtype=dtypes.uint64))
+ handle = ged_ops.experimental_materialized_index_dataset_handle(
+ container="",
+ shared_name="",
+ output_types=[dtypes.uint64],
+ output_shapes=[[]])
+ materialize = ged_ops.experimental_indexed_dataset_materialize(
+ identity, handle)
+ index = array_ops.placeholder(dtypes.uint64)
+ get_op = ged_ops.experimental_indexed_dataset_get(
+ handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
+
+ with self.cached_session() as sess:
+ sess.run(materialize)
+ self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
+
+ def testIdentityIndexedDataset(self):
+ ds = indexed_dataset_ops.IdentityIndexedDataset(16)
+ materialized = ds.materialize()
+ with self.cached_session() as sess:
+ sess.run(materialized.initializer)
+ placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
+ for i in range(16):
+ output = sess.run(
+ materialized.get(placeholder), feed_dict={placeholder: i})
+ self.assertEqual([i], output)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(materialized.get(placeholder), feed_dict={placeholder: 16})
+
+ @unittest.skip("Requisite functionality currently unimplemented.")
+ def testIdentityIndexedDatasetIterator(self):
+ ds = indexed_dataset_ops.IdentityIndexedDataset(16)
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.cached_session() as sess:
+ sess.run(itr.initializer)
+ for i in range(16):
+ output = sess.run(n)
+ self.assertEqual(i, output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py
new file mode 100644
index 0000000000..560902caad
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py
@@ -0,0 +1,811 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import math
+import threading
+import time
+
+from six.moves import zip_longest
+
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+
+ self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
+ self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
+ self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
+ self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
+ self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[])
+ self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[])
+
+ self.error = None
+ self.repeat_count = 2
+
+ # Set up threading events used to sequence when items are produced that
+ # are subsequently interleaved. These events allow us to deterministically
+ # simulate slowdowns and force sloppiness.
+ self.read_coordination_events = {}
+ self.write_coordination_events = {}
+ # input values [4, 5, 6] are the common case for the tests; set defaults
+ for i in range(4, 7):
+ self.read_coordination_events[i] = threading.Semaphore(0)
+ self.write_coordination_events[i] = threading.Event()
+
+ def map_py_fn(x):
+ self.write_coordination_events[x].wait()
+ self.write_coordination_events[x].clear()
+ self.read_coordination_events[x].release()
+ if self.error:
+ err = self.error
+ self.error = None
+ raise err # pylint: disable=raising-bad-type
+ return x * x
+
+ def map_fn(x):
+ return script_ops.py_func(map_py_fn, [x], x.dtype)
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ dataset = dataset.repeat(x)
+ return dataset.map(map_fn)
+
+ self.dataset = (
+ dataset_ops.Dataset.from_tensor_slices(self.input_values)
+ .repeat(self.repeat_count).apply(
+ interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
+ self.block_length, self.sloppy,
+ self.buffer_output_elements,
+ self.prefetch_input_elements)))
+ self.iterator = self.dataset.make_initializable_iterator()
+ self.init_op = self.iterator.initializer
+ self.next_element = self.iterator.get_next()
+
+ def _interleave(self, lists, cycle_length, block_length):
+ """Python implementation of interleave used for testing."""
+ num_open = 0
+
+ # `all_iterators` acts as a queue of iterators over each element of `lists`.
+ all_iterators = [iter(l) for l in lists]
+
+ # `open_iterators` are the iterators whose elements are currently being
+ # interleaved.
+ open_iterators = []
+ for i in range(cycle_length):
+ if all_iterators:
+ open_iterators.append(all_iterators.pop(0))
+ num_open += 1
+ else:
+ open_iterators.append(None)
+
+ while num_open or all_iterators:
+ for i in range(cycle_length):
+ if open_iterators[i] is None:
+ if all_iterators:
+ open_iterators[i] = all_iterators.pop(0)
+ num_open += 1
+ else:
+ continue
+ for _ in range(block_length):
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ open_iterators[i] = None
+ num_open -= 1
+ break
+
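+ # For example, `self._interleave([[1, 1], [2, 2, 2]], 2, 1)` yields
+ # 1, 2, 1, 2, 2: with cycle length 2 the two iterators are drawn from
+ # alternately, one element at a time (block length 1), until both are
+ # exhausted.
+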
+ def testPythonImplementation(self):
+ input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
+ [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
+
+ # Cycle length 1 acts like `Dataset.flat_map()`.
+ expected_elements = itertools.chain(*input_lists)
+ for expected, produced in zip(expected_elements,
+ self._interleave(input_lists, 1, 1)):
+ self.assertEqual(expected, produced)
+
+ # Cycle length > 1.
+ expected_elements = [
+ 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
+ 6, 5, 6, 5, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def testPythonImplementationBlockLength(self):
+ input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
+ expected_elements = [
+ 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
+ 5, 6, 6, 5, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def testPythonImplementationEmptyLists(self):
+ input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
+ [6, 6, 6, 6, 6, 6]]
+
+ expected_elements = [
+ 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def _clear_coordination_events(self):
+ for i in range(4, 7):
+ self.read_coordination_events[i] = threading.Semaphore(0)
+ self.write_coordination_events[i].clear()
+
+ def _allow_all_map_threads(self):
+ for i in range(4, 7):
+ self.write_coordination_events[i].set()
+
+ def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
+ # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
+ # `Dataset.flat_map()` and is single-threaded. No synchronization required.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 1,
+ self.block_length: 1,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: prefetch_input_elements,
+ })
+
+ for expected_element in self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
+ self.write_coordination_events[expected_element].set()
+ self.assertEqual(expected_element * expected_element,
+ sess.run(self.next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testSingleThreaded(self):
+ self._testSingleThreaded()
+
+ def testSingleThreadedSloppy(self):
+ self._testSingleThreaded(sloppy=True)
+
+ def testSingleThreadedPrefetch1Itr(self):
+ self._testSingleThreaded(prefetch_input_elements=1)
+
+ def testSingleThreadedPrefetch1ItrSloppy(self):
+ self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
+
+ def testSingleThreadedRagged(self):
+ # Tests a sequence with wildly different elements per iterator.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [3, 7, 4],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: False,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+
+ # Add coordination values for 3 and 7
+ self.read_coordination_events[3] = threading.Semaphore(0)
+ self.write_coordination_events[3] = threading.Event()
+ self.read_coordination_events[7] = threading.Semaphore(0)
+ self.write_coordination_events[7] = threading.Event()
+
+ for expected_element in self._interleave(
+ [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
+ self.write_coordination_events[expected_element].set()
+ output = sess.run(self.next_element)
+ self.assertEqual(expected_element * expected_element, output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def _testTwoThreadsNoContention(self, sloppy=False):
+ # num_threads > 1.
+ # Explicit coordination should result in `Dataset.interleave()` behavior
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 1)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ self.read_coordination_events[expected_element].acquire()
+ done_first_event = True
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContention(self):
+ self._testTwoThreadsNoContention()
+
+ def testTwoThreadsNoContentionSloppy(self):
+ self._testTwoThreadsNoContention(sloppy=True)
+
+ def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
+ """Tests where all the workers race in producing elements.
+
+ Note: this is in contrast with the previous test which carefully sequences
+ the execution of the map functions.
+
+ Args:
+ sloppy: Whether to be sloppy or not.
+ """
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 1)):
+ if done_first_event: # First event starts the worker threads.
+ self._allow_all_map_threads()
+ self.read_coordination_events[expected_element].acquire()
+ else:
+ self.write_coordination_events[expected_element].set()
+ time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.assertTrue(
+ self.read_coordination_events[expected_element].acquire(False))
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionWithRaces(self):
+ self._testTwoThreadsNoContentionWithRaces()
+
+ def testTwoThreadsNoContentionWithRacesSloppy(self):
+ self._testTwoThreadsNoContentionWithRaces(sloppy=True)
+
+ def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
+ # num_threads > 1.
+ # Explicit coordination should result in `Dataset.interleave()` behavior
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 2,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 2)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.read_coordination_events[expected_element].acquire()
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionBlockLength(self):
+ self._testTwoThreadsNoContentionBlockLength()
+
+ def testTwoThreadsNoContentionBlockLengthSloppy(self):
+ self._testTwoThreadsNoContentionBlockLength(sloppy=True)
+
+ def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
+ """Tests where all the workers race in producing elements.
+
+ Note: this is in contrast with the previous test which carefully sequences
+ the execution of the map functions.
+
+ Args:
+ sloppy: Whether to be sloppy or not.
+ """
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 2,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 2)):
+ if done_first_event: # First event starts the worker threads.
+ self._allow_all_map_threads()
+ self.read_coordination_events[expected_element].acquire()
+ else:
+ self.write_coordination_events[expected_element].set()
+ time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.assertTrue(
+ self.read_coordination_events[expected_element].acquire(False))
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionWithRacesAndBlocking(self):
+ self._testTwoThreadsNoContentionWithRacesAndBlocking()
+
+ def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
+ self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
+
+ def _testEmptyInput(self, sloppy=False):
+ with self.cached_session() as sess:
+ # Empty input.
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [],
+ self.cycle_length: 2,
+ self.block_length: 3,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testEmptyInput(self):
+ self._testEmptyInput()
+
+ def testEmptyInputSloppy(self):
+ self._testEmptyInput(sloppy=True)
+
+ def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
+ # Non-empty input leading to empty output.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [0, 0, 0],
+ self.cycle_length: 2,
+ self.block_length: 3,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testNonEmptyInputIntoEmptyOutputs(self):
+ self._testNonEmptyInputIntoEmptyOutputs()
+
+ def testNonEmptyInputIntoEmptyOutputsSloppy(self):
+ self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
+
+ def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
+ race_indices = {2, 8, 14} # Sequence points where sloppy mode has race conditions
+ # Mixture of non-empty and empty interleaved datasets.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 0, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: prefetch_input_elements,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
+ self.write_coordination_events[expected_element].set()
+ # First event starts the worker threads. Additionally, when running the
+ # sloppy case with prefetch_input_elements=0, we get stuck if we wait
+ # for the read coordination event for certain event orderings in the
+ # presence of finishing iterators.
+ if done_first_event and not (sloppy and (i in race_indices)):
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event or (sloppy and (i in race_indices)):
+ done_first_event = True
+ self.read_coordination_events[expected_element].acquire()
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+
+ def testPartiallyEmptyOutputs(self):
+ self._testPartiallyEmptyOutputs()
+
+ def testPartiallyEmptyOutputsSloppy(self):
+ self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
+
+ def testDelayedOutputSloppy(self):
+ # Explicitly control the sequence of events to ensure we correctly avoid
+ # head-of-line blocking.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: True,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+
+ mis_ordering = [
+ 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6,
+ 6, 5, 5, 5, 5, 6, 6
+ ]
+ for element in mis_ordering:
+ self.write_coordination_events[element].set()
+ self.assertEqual(element * element, sess.run(self.next_element))
+ self.assertTrue(self.read_coordination_events[element].acquire(False))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testBlockLengthWithContentionSloppy(self):
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: True,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ # Test against a generating sequence that differs from the uncontended
+ # case, in order to prove sloppy correctness.
+ for i, expected_element in enumerate(
+ self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
+ cycle_length=2,
+ block_length=3)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ self.read_coordination_events[expected_element].acquire()
+ done_first_event = True
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def _testEarlyExit(self, sloppy=False):
+ # Exiting without consuming all input should not block
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 3,
+ self.block_length: 2,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ for i in range(4, 7):
+ self.write_coordination_events[i].set()
+ elem = sess.run(self.next_element) # Start all workers
+ # Allow the one successful worker to progress beyond the py_func again.
+ elem = int(math.sqrt(elem))
+ self.write_coordination_events[elem].set()
+ self.read_coordination_events[elem].acquire()
+ # Allow the prefetch to succeed
+ for i in range(4, 7):
+ self.read_coordination_events[i].acquire()
+ self.write_coordination_events[i].set()
+
+ def testEarlyExit(self):
+ self._testEarlyExit()
+
+ def testEarlyExitSloppy(self):
+ self._testEarlyExit(sloppy=True)
+
+ def _testTooManyReaders(self, sloppy=False):
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
+ return dataset
+
+ dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
+ dataset = dataset.repeat(self.repeat_count)
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
+ iterator = dataset.make_one_shot_iterator()
+
+ with self.cached_session() as sess:
+ output_values = []
+ for _ in range(30):
+ output_values.append(sess.run(iterator.get_next()))
+
+ expected_values = self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
+ self.assertItemsEqual(output_values, expected_values)
+
+ def testTooManyReaders(self):
+ self._testTooManyReaders()
+
+ def testTooManyReadersSloppy(self):
+ self._testTooManyReaders(sloppy=True)
+
+ def testSparse(self):
+ def _map_fn(i):
+ return sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ dataset = dataset_ops.Dataset.range(10).map(_map_fn)
+ iterator = dataset.apply(
+ interleave_ops.parallel_interleave(
+ _interleave_fn, cycle_length=1)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for i in range(10):
+ for j in range(2):
+ expected = [i, 0] if j % 2 == 0 else [0, -i]
+ self.assertAllEqual(expected, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testErrorsInOutputFn(self):
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: False,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+
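+      # Inject an error into the output function for element index 3 only.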
+ except_on_element_indices = set([3])
+
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 1)):
+ if i in except_on_element_indices:
+ self.error = ValueError()
+ self.write_coordination_events[expected_element].set()
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(self.next_element)
+ else:
+ self.write_coordination_events[expected_element].set()
+ actual_element = sess.run(self.next_element)
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testErrorsInInputFn(self):
+
+ def map_py_fn(x):
+ if x == 5:
+ raise ValueError()
+ return x
+
+ def map_fn(x):
+ return script_ops.py_func(map_py_fn, [x], x.dtype)
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ dataset = dataset.repeat(x)
+ return dataset
+
+ self.dataset = (
+ dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn)
+ .repeat(self.repeat_count).apply(
+ interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
+ self.block_length, self.sloppy,
+ self.buffer_output_elements,
+ self.prefetch_input_elements)))
+
+ self.iterator = self.dataset.make_initializable_iterator()
+ self.init_op = self.iterator.initializer
+ self.next_element = self.iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: False,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
+ if expected_element == 5:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(self.next_element)
+ else:
+ actual_element = sess.run(self.next_element)
+ self.assertEqual(expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testErrorsInInterleaveFn(self):
+
+ def map_py_fn(x):
+ if x == 5:
+ raise ValueError()
+ return x
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ y = script_ops.py_func(map_py_fn, [x], x.dtype)
+ dataset = dataset.repeat(y)
+ return dataset
+
+ self.dataset = (
+ dataset_ops.Dataset.from_tensor_slices(self.input_values)
+ .repeat(self.repeat_count).apply(
+ interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
+ self.block_length, self.sloppy,
+ self.buffer_output_elements,
+ self.prefetch_input_elements)))
+
+ self.iterator = self.dataset.make_initializable_iterator()
+ self.init_op = self.iterator.initializer
+ self.next_element = self.iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: False,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
+ if expected_element == 5:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(self.next_element)
+ else:
+ actual_element = sess.run(self.next_element)
+ self.assertEqual(expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testShutdownRace(self):
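+    # Draining the iterator twice after re-initialization should produce
+    # identical results and must not race with worker-thread shutdown.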
+ dataset = dataset_ops.Dataset.range(20)
+ map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ map_fn,
+ cycle_length=3,
+ sloppy=False,
+ buffer_output_elements=1,
+ prefetch_input_elements=0))
+ dataset = dataset.batch(32)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ results = []
+ with self.cached_session() as sess:
+ for _ in range(2):
+ elements = []
+ sess.run(iterator.initializer)
+ try:
+ while True:
+ elements.extend(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ results.append(elements)
+
+ self.assertAllEqual(results[0], results[1])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py
new file mode 100644
index 0000000000..94393d6d4b
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental iterator_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import iterator_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import training_util
+
+
+class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
+
+ @staticmethod
+ def _model_fn(features, labels, mode, config):
+ del labels
+ del mode
+ del config
+ global_step = training_util.get_or_create_global_step()
+ update_global_step_op = global_step.assign_add(1)
+ latest_feature = variables.VariableV1(
+ 0, name='latest_feature', dtype=dtypes.int64)
+ store_latest_feature_op = latest_feature.assign(features)
+ ops.add_to_collection('my_vars', global_step)
+ ops.add_to_collection('my_vars', latest_feature)
+ return model_fn.EstimatorSpec(
+ mode='train',
+ train_op=control_flow_ops.group(
+ [update_global_step_op, store_latest_feature_op]),
+ loss=constant_op.constant(2.0))
+
+ def _read_vars(self, model_dir):
+ """Returns (global_step, latest_feature)."""
+ with ops.Graph().as_default() as g:
+ ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
+ meta_filename = ckpt_path + '.meta'
+ saver_lib.import_meta_graph(meta_filename)
+ saver = saver_lib.Saver()
+ with self.session(graph=g) as sess:
+ saver.restore(sess, ckpt_path)
+ return sess.run(ops.get_collection('my_vars'))
+
+ def _build_iterator_saver_hook(self, est):
+ return iterator_ops.CheckpointInputPipelineHook(est)
+
+ def testReturnDatasetFromInputFn(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.range(10)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
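+    # The hook checkpoints the iterator state, so the second call resumes
+    # the pipeline at element 2 rather than restarting from 0.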
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+ def testBuildIteratorInInputFn(self):
+
+ def _input_fn():
+ ds = dataset_ops.Dataset.range(10)
+ iterator = ds.make_one_shot_iterator()
+ return iterator.get_next()
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+ def testDoNotRestore(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.range(10)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+    # The hook is not provided this time, so the input pipeline is not
+    # restored and restarts from the beginning.
+ est.train(_input_fn, steps=2)
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
+
+ def testRaiseErrorIfNoIterator(self):
+
+ def _input_fn():
+ return constant_op.constant(1, dtype=dtypes.int64)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ with self.assertRaises(ValueError):
+ est.train(
+ _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py
new file mode 100644
index 0000000000..2f0bd1456b
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py
@@ -0,0 +1,359 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import hashlib
+import itertools
+import os
+import time
+
+import numpy as np
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+_NUMPY_RANDOM_SEED = 42
+
+
+class MapDatasetTest(test_base.DatasetTestBase):
+
+ def testMapIgnoreError(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.check_numerics(x, "message")).apply(
+ error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for x in [1., 2., 3., 5.]:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testParallelMapIgnoreError(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.check_numerics(x, "message"),
+ num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for x in [1., 2., 3., 5.]:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testReadFileIgnoreError(self):
+
+ def write_string_to_file(value, filename):
+ with open(filename, "w") as f:
+ f.write(value)
+
+ filenames = [
+ os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
+ ]
+ for filename in filenames:
+ write_string_to_file(filename, filename)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(filenames).map(
+ io_ops.read_file,
+ num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # All of the files are present.
+ sess.run(init_op)
+ for filename in filenames:
+ self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Delete one of the files.
+ os.remove(filenames[0])
+
+ # Attempting to read filenames[0] will fail, but ignore_errors()
+ # will catch the error.
+ sess.run(init_op)
+ for filename in filenames[1:]:
+ self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testCaptureResourceInMapFn(self):
+
+ def _build_ds(iterator):
+
+ def _map_fn(x):
+ get_next = iterator.get_next()
+ return x * get_next
+
+ return dataset_ops.Dataset.range(10).map(_map_fn)
+
+ def _build_graph():
+ captured_iterator = dataset_ops.Dataset.range(
+ 10).make_initializable_iterator()
+ ds = _build_ds(captured_iterator)
+ iterator = ds.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return captured_iterator.initializer, init_op, get_next
+
+ with ops.Graph().as_default() as g:
+ captured_init_op, init_op, get_next = _build_graph()
+ with self.session(graph=g) as sess:
+ sess.run(captured_init_op)
+ sess.run(init_op)
+ for i in range(10):
+          self.assertEqual(i * i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+class MapDatasetBenchmark(test.Benchmark):
+
+ # The purpose of this benchmark is to compare the performance of chaining vs
+ # fusing of the map and batch transformations across various configurations.
+ #
+ # NOTE: It is recommended to build the benchmark with
+ # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
+ # and execute it on a machine with at least 32 CPU cores.
+ def benchmarkMapAndBatch(self):
+
+ # Sequential pipeline configurations.
+ seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
+ seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+
+ # Parallel pipeline configuration.
+ par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+ par_batch_size_series = itertools.product([32], [32], [1],
+ [128, 256, 512, 1024])
+ par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
+ par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
+
+ def name(method, label, num_calls, inter_op, element_size, batch_size):
+ return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
+ method,
+          hashlib.sha1(compat.as_bytes(label)).hexdigest(),
+ num_calls,
+ inter_op,
+ element_size,
+ batch_size,
+ ))
+
+ def benchmark(label, series):
+
+ print("%s:" % label)
+ for num_calls, inter_op, element_size, batch_size in series:
+
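+        # Use fewer iterations for configurations with a larger per-step cost
+        # (bigger elements and batches, less parallelism).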
+ num_iters = 1024 // (
+ (element_size * batch_size) // min(num_calls, inter_op))
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
+ element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
+
+ chained_dataset = dataset.map(
+ math_ops.matmul,
+ num_parallel_calls=num_calls).batch(batch_size=batch_size)
+ chained_iterator = chained_dataset.make_one_shot_iterator()
+ chained_get_next = chained_iterator.get_next()
+
+ chained_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+ for _ in range(5):
+ sess.run(chained_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(chained_get_next.op)
+ end = time.time()
+ chained_deltas.append(end - start)
+
+ fused_dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=num_calls,
+ batch_size=batch_size))
+ fused_iterator = fused_dataset.make_one_shot_iterator()
+ fused_get_next = fused_iterator.get_next()
+
+ fused_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+
+ for _ in range(5):
+ sess.run(fused_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(fused_get_next.op)
+ end = time.time()
+ fused_deltas.append(end - start)
+
+ print(
+ "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
+ "element size: %d, num iters: %d\nchained wall time: %f (median), "
+ "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
+ "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
+ "chained/fused: %.2fx (median), %.2fx (mean)" %
+ (batch_size, num_calls, inter_op, element_size, num_iters,
+ np.median(chained_deltas), np.mean(chained_deltas),
+ np.std(chained_deltas), np.min(chained_deltas),
+ np.max(chained_deltas), np.median(fused_deltas),
+ np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
+ np.max(fused_deltas),
+ np.median(chained_deltas) / np.median(fused_deltas),
+ np.mean(chained_deltas) / np.mean(fused_deltas)))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(chained_deltas),
+ name=name("chained", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(fused_deltas),
+ name=name("fused", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ print("")
+
+ np.random.seed(_NUMPY_RANDOM_SEED)
+ benchmark("Sequential element size evaluation", seq_elem_size_series)
+ benchmark("Sequential batch size evaluation", seq_batch_size_series)
+ benchmark("Parallel element size evaluation", par_elem_size_series)
+ benchmark("Parallel batch size evaluation", par_batch_size_series)
+ benchmark("Transformation parallelism evaluation", par_num_calls_series)
+ benchmark("Threadpool size evaluation", par_inter_op_series)
+
+  # This benchmark compares the performance of a pipeline with multiple
+  # chained maps, with and without map fusion.
+ def benchmarkChainOfMaps(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkChainOfMaps(chain_length, False)
+ self._benchmarkChainOfMaps(chain_length, True)
+
+ def _benchmarkChainOfMaps(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.map(lambda x: x)
+ if optimize_dataset:
+ dataset = dataset.apply(optimization.optimize(["map_fusion"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
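+        # Each recorded delta covers 100 session runs, so divide by 100 to
+        # obtain the per-element wall time.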
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "no-opt"
+ print("Map dataset {} chain length: {} Median wall time: {}".format(
+ opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_map_dataset_chain_latency_{}_{}".format(
+ opt_mark, chain_length))
+
+
+class MapAndFilterBenchmark(test.Benchmark):
+
+  # This benchmark compares the performance of a pipeline with multiple
+  # chained map + filter transformations, with and without fusion.
+ def benchmarkMapAndFilter(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkMapAndFilter(chain_length, False)
+ self._benchmarkMapAndFilter(chain_length, True)
+
+ def _benchmarkMapAndFilter(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.map(lambda x: x + 5).filter(
+ lambda x: math_ops.greater_equal(x - 5, 0))
+ if optimize_dataset:
+ dataset = dataset.apply(
+ optimization.optimize(["map_and_filter_fusion"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(10):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "no-opt"
+ print("Map and filter dataset {} chain length: {} Median wall time: {}".
+ format(opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format(
+ opt_mark, chain_length))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
new file mode 100644
index 0000000000..612ee332c4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -0,0 +1,281 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MapDefunOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import map_defun
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapDefunTest(test_base.DatasetTestBase):
+
+ def testMapDefunSimple(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
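+    # map_defun(fn, inputs, output_dtypes, output_shapes) maps `simple_fn`
+    # over dimension 0 of `elems`; each per-element output has shape (2,).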
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+
+ def testMapDefunMismatchedTypes(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return math_ops.cast(x, dtypes.float64)
+
+ nums = [1, 2, 3, 4, 5, 6]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunReduceDim(self):
+ # Tests where the output has a different rank from the input
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return array_ops.gather(x, 0)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+ expected = constant_op.constant([1, 3, 5])
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+
+ def testMapDefunMultipleOutputs(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return (x, math_ops.cast(x * 2 + 3, dtypes.float64))
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
+ (2,)])
+ expected = [elems, elems * 2 + 3]
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+
+ def testMapDefunShapeInference(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return x
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
+ self.assertEqual(result.get_shape(), (3, 2))
+
+ def testMapDefunPartialShapeInference(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return x
+
+ elems = array_ops.placeholder(dtypes.int64, (None, 2))
+ result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
+ self.assertEqual(result[0].get_shape().as_list(), [None, 2])
+
+ def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
+
+ @function.Defun(dtypes.int32, dtypes.int32)
+ def fn(x, y):
+ return x, y
+
+ elems1 = array_ops.placeholder(dtypes.int32)
+ elems2 = array_ops.placeholder(dtypes.int32)
+ result = map_defun.map_defun(fn, [elems1, elems2],
+ [dtypes.int32, dtypes.int32], [(), ()])
+ with self.cached_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ "All inputs must have the same dimension 0."):
+ sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
+
+ def testMapDefunRaisesDefunError(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+ return array_ops.identity(x)
+
+ elems = constant_op.constant([0, 0, 0, 37, 0])
+ result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(result)
+
+ def testMapDefunCancelledCorrectly(self):
+
+ @function.Defun(dtypes.int64)
+ def defun(x):
+      # x has leading dimension 5, so gathering index 10 will raise an error.
+ return array_ops.gather(x, 10)
+
+ c = array_ops.tile(
+ array_ops.expand_dims(
+ constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
+ [100, 1])
+ map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(map_defun_op)
+
+ def testMapDefunWithUnspecifiedOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ res = x * 2 + 3
+ return (res, res + 1, res + 2)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems],
+ [dtypes.int32, dtypes.int32, dtypes.int32],
+ [None, (None,), (2,)])
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+ self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+ self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+ def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ elems = array_ops.placeholder(dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+ self.assertAllEqual(
+ sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+ def testMapDefunWithWrongOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunWithInvalidInput(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2
+
+ c = constant_op.constant(2)
+ with self.assertRaises(ValueError):
+ # Fails at graph construction time for inputs with known shapes.
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+ p = array_ops.placeholder(dtypes.int32)
+ r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(r, feed_dict={p: 0})
+
+ def _assert_op_cancelled(self, sess, map_defun_op):
+ with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
+ sess.run(map_defun_op)
+
+ def testMapDefunWithParentCancellation(self):
+ # Checks that a cancellation of the parent graph is threaded through to
+ # MapDefunOp correctly.
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ del x
+ queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
+      # The queue is never fed, so this dequeue blocks until the op is
+      # cancelled along with the parent session.
+ return queue.dequeue_many(5)
+
+ c = constant_op.constant([1, 2, 3, 4, 5])
+ map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
+
+ with self.cached_session() as sess:
+ thread = self.checkedThread(
+ self._assert_op_cancelled, args=(sess, map_defun_op))
+ thread.start()
+ time.sleep(0.1)
+ sess.close()
+ thread.join()
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+ def _run(self, op, name=None, num_iters=3000):
+ with session.Session() as sess:
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op)
+ start = time.time()
+ for _ in range(num_iters):
+ sess.run(op)
+ end = time.time()
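+      # Mean wall time per iteration, in microseconds.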
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmarkDefunVsMapFn(self):
+ """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+
+ @function.Defun(dtypes.int32)
+ def defun(x):
+ return array_ops.identity(x)
+
+ def map_fn(x):
+ return array_ops.identity(x)
+
+    for input_size in [10, 100, 1000, 10000]:
+      num_iters = 100000 // input_size
+      base = math_ops.range(input_size)
+ map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+ map_fn_op = functional_ops.map_fn(map_fn, base)
+
+ self._run(
+ map_defun_op,
+ "benchmarkMapDefun_size_%d" % input_size,
+ num_iters=num_iters)
+ self._run(
+ map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
new file mode 100644
index 0000000000..68f73bddb5
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -0,0 +1,164 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "assert_next_dataset_op_test",
+ size = "medium",
+ srcs = ["assert_next_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "hoist_random_uniform_test",
+ size = "small",
+ srcs = ["hoist_random_uniform_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "latency_all_edges_test",
+ size = "small",
+ srcs = ["latency_all_edges_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "map_vectorization_test",
+ size = "small",
+ srcs = ["map_vectorization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_and_filter_fusion_test",
+ size = "medium",
+ srcs = ["map_and_filter_fusion_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_parallelization_test",
+ size = "small",
+ srcs = ["map_parallelization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "model_dataset_op_test",
+ size = "medium",
+ srcs = ["model_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "noop_elimination_test",
+ size = "small",
+ srcs = ["noop_elimination_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "optimize_dataset_op_test",
+ size = "small",
+ srcs = ["optimize_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py
new file mode 100644
index 0000000000..45b77b5c20
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -0,0 +1,65 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class AssertNextDatasetTest(test_base.DatasetTestBase):
+
+ def testAssertNext(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+
+ def testAssertNextInvalid(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted Whoops transformation at offset 0 but encountered "
+ "Map transformation instead."):
+ sess.run(get_next)
+
+ def testAssertNextShort(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted next 2 transformations but encountered only 1."):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
new file mode 100644
index 0000000000..3cd9753665
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -0,0 +1,103 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for HostState optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
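+    # Each case is (test name, map function, whether hoist_random_uniform
+    # should rewrite the map).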
+ plus_one = lambda x: x + 1
+
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=1,
+ maxval=10,
+ dtype=dtypes.float32,
+ seed=42)
+
+ def random_with_assert(x):
+ y = random(x)
+ assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
+ with ops.control_dependencies([assert_op]):
+ return y
+
+ twice_random = lambda x: (random(x) + random(x)) / 2.
+
+ tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
+ ("RandomWithAssert", random_with_assert, True),
+ ("TwiceRandom", twice_random, False)]
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testHoisting(self, function, will_optimize):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(
+ ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
+
+ dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"]))
+ self._testDataset(dataset)
+
+ def testAdditionalInputs(self):
+ a = constant_op.constant(1, dtype=dtypes.float32)
+ b = constant_op.constant(0, dtype=dtypes.float32)
+ some_tensor = math_ops.mul(a, b)
+
+ def random_with_capture(_):
+ return some_tensor + random_ops.random_uniform(
+ [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
+
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(
+ ["Zip[0]", "Map"])).map(random_with_capture).apply(
+ optimization.optimize(["hoist_random_uniform"]))
+ self._testDataset(dataset)
+
+ def _testDataset(self, dataset):
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ previous_result = 0
+ with self.cached_session() as sess:
+ for _ in range(5):
+ result = sess.run(get_next)
+ self.assertLessEqual(1, result)
+ self.assertLessEqual(result, 10)
+        # Check that the output is random-like by verifying that consecutive
+        # values differ.
+ self.assertNotEqual(previous_result, result)
+ previous_result = result
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
new file mode 100644
index 0000000000..45623876ae
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the LatencyAllEdges optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+
+ def testLatencyStatsOptimization(self):
+
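+    # The latency_all_edges optimization should insert a LatencyStats node
+    # after each dataset op; assert_next verifies the rewritten op order.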
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.from_tensors(1).apply(
+ optimization.assert_next(
+ ["LatencyStats", "Map", "LatencyStats", "Prefetch",
+ "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator)).apply(
+ optimization.optimize(["latency_all_edges"]))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertEqual(1 * 1, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_TensorDataset/_1", 1)
+ self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
+ 1)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_PrefetchDataset/_6", 1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
new file mode 100644
index 0000000000..a439635716
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -0,0 +1,225 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndFilterFusion optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
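+    # Build (test name, [map functions]) cases covering every pair and
+    # triple of the basic functions, plus two multi-output cases.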
+ identity = lambda x: x
+ increment = lambda x: x + 1
+
+ def increment_and_square(x):
+ y = x + 1
+ return y * y
+
+ functions = [identity, increment, increment_and_square]
+ tests = []
+ for i, fun1 in enumerate(functions):
+ for j, fun2 in enumerate(functions):
+ tests.append((
+ "Test{}{}".format(i, j),
+ [fun1, fun2],
+ ))
+ for k, fun3 in enumerate(functions):
+ tests.append((
+ "Test{}{}{}".format(i, j, k),
+ [fun1, fun2, fun3],
+ ))
+
+ swap = lambda x, n: (n, x)
+ tests.append((
+ "Swap1",
+ [lambda x: (x, 42), swap],
+ ))
+ tests.append((
+ "Swap2",
+ [lambda x: (x, 42), swap, swap],
+ ))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testMapFusion(self, functions):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Map", "Prefetch"]))
+ for function in functions:
+ dataset = dataset.map(function)
+
+ dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ r = x
+ for function in functions:
+ if isinstance(r, tuple):
+ r = function(*r) # Pass tuple as multiple arguments.
+ else:
+ r = function(r)
+ self.assertAllEqual(r, result)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @staticmethod
+ def map_and_filter_functions():
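+    # Each case is (test name, map function, filter predicate).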
+ identity = lambda x: x
+ increment = lambda x: x + 1
+ minus_five = lambda x: x - 5
+
+ def increment_and_square(x):
+ y = x + 1
+ return y * y
+
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+    is_even = lambda x: math_ops.equal(x % 2, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+
+ functions = [identity, increment, minus_five, increment_and_square]
+    filters = [take_all, is_zero, is_even, greater]
+ tests = []
+
+ for x, fun in enumerate(functions):
+ for y, predicate in enumerate(filters):
+ tests.append(("Mixed{}{}".format(x, y), fun, predicate))
+
+ # Multi output
+ tests.append(("Multi1", lambda x: (x, x),
+ lambda x, y: constant_op.constant(True)))
+ tests.append(
+ ("Multi2", lambda x: (x, 2),
+ lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_and_filter_functions.__func__())
+ def testMapFilterFusion(self, function, predicate):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map",
+ "FilterByLastComponent"])).map(function).filter(predicate).apply(
+ optimization.optimize(["map_and_filter_fusion"]))
+ self._testMapAndFilter(dataset, function, predicate)
+
+ def _testMapAndFilter(self, dataset, function, predicate):
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for x in range(10):
+ r = function(x)
+ if isinstance(r, tuple):
+ b = predicate(*r) # Pass tuple as multiple arguments.
+ else:
+ b = predicate(r)
+ if sess.run(b):
+ result = sess.run(get_next)
+ self.assertAllEqual(r, result)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testAdditionalInputs(self):
+ a = constant_op.constant(3, dtype=dtypes.int64)
+ b = constant_op.constant(4, dtype=dtypes.int64)
+ some_tensor = math_ops.mul(a, b)
+ function = lambda x: x * x
+
+ def predicate(y):
+ return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)
+
+    # Fusion does not currently support functions that capture additional
+    # inputs, so the Map and Filter are expected to remain separate.
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Filter"])).map(function).filter(predicate).apply(
+ optimization.optimize(["map_and_filter_fusion"]))
+
+ self._testMapAndFilter(dataset, function, predicate)
+
+ @staticmethod
+ def filter_functions():
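+    # Each case is (test name, map function, [filter predicates]).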
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+
+ tests = []
+ filters = [take_all, is_zero, greater]
+ identity = lambda x: x
+ for x, predicate_1 in enumerate(filters):
+ for y, predicate_2 in enumerate(filters):
+ tests.append(("Mixed{}{}".format(x, y), identity,
+ [predicate_1, predicate_2]))
+ for z, predicate_3 in enumerate(filters):
+ tests.append(("Mixed{}{}{}".format(x, y, z), identity,
+ [predicate_1, predicate_2, predicate_3]))
+
+ take_all_multiple = lambda x, y: constant_op.constant(True)
+ # Multi output
+ tests.append(("Multi1", lambda x: (x, x),
+ [take_all_multiple, take_all_multiple]))
+ tests.append(("Multi2", lambda x: (x, 2), [
+ take_all_multiple,
+ lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+ ]))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*filter_functions.__func__())
+ def testFilterFusion(self, map_function, predicates):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Map", "Filter",
+ "Prefetch"])).map(map_function)
+ for predicate in predicates:
+ dataset = dataset.filter(predicate)
+
+ dataset = dataset.prefetch(0).apply(
+ optimization.optimize(["filter_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for x in range(5):
+ r = map_function(x)
+ filtered = False
+ for predicate in predicates:
+ if isinstance(r, tuple):
+ b = predicate(*r) # Pass tuple as multiple arguments.
+ else:
+ b = predicate(r)
+ if not sess.run(b):
+ filtered = True
+ break
+
+ if not filtered:
+ result = sess.run(get_next)
+ self.assertAllEqual(r, result)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
new file mode 100644
index 0000000000..334d8e3778
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ identity = lambda x: x
+ increment = lambda x: x + 1
+
+ def assert_greater(x):
+ assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
+ with ops.control_dependencies([assert_op]):
+ return x
+
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=0,
+ maxval=10,
+ dtype=dtypes.int64,
+ seed=42)
+
+ def assert_with_random(x):
+ x = assert_greater(x)
+ return random(x)
+
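+    # Each tuple below is (test case name, map function, should_optimize).
+    # The expectation (an assumption about the rewrite, not a guarantee) is
+    # that purely stateless functions become a ParallelMap, while functions
+    # containing stateful ops such as random_uniform stay as a sequential Map.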
+ return (("Identity", identity, True), ("Increment", increment, True),
+ ("AssertGreater", assert_greater, True), ("Random", random, False),
+ ("AssertWithRandom", assert_with_random, False))
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testMapParallelization(self, function, should_optimize):
+ next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(next_nodes)).map(function).apply(
+ optimization.optimize(["map_parallelization"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+        # No need to run the whole pipeline if it was not optimized; the
+        # results would also be hard to check because of the randomness.
+ if not should_optimize:
+ return
+ r = function(x)
+ self.assertAllEqual(r, result)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
new file mode 100644
index 0000000000..d47492753e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -0,0 +1,223 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapVectorization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def _get_test_datasets(self,
+ base_dataset,
+ map_fn,
+ num_parallel_calls=None,
+ expect_optimized=True):
+ """Given base dataset and map fn, creates test datasets.
+
+    Returns a tuple of (unoptimized dataset, optimized dataset). The
+    unoptimized dataset asserts that Batch follows Map. The optimized dataset
+    asserts that Map follows Batch, and has the "map_vectorization"
+    optimization applied.
+
+ Args:
+ base_dataset: Input dataset to map->batch
+ map_fn: Map function to use
+ num_parallel_calls: (Optional.) num_parallel_calls argument for map
+      expect_optimized: (Optional.) Whether we expect the optimization to take
+        place, in which case we assert that Batch is followed by Map;
+        otherwise, that Map is followed by Batch. Defaults to True.
+
+ Returns:
+ Tuple of (unoptimized dataset, optimized dataset).
+ """
+ map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
+ batch_size = 100
+
+ def _make_dataset(node_names):
+ return base_dataset.apply(optimization.assert_next(node_names)).map(
+ map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
+
+ unoptimized = _make_dataset([map_node_name, "Batch"])
+ optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
+ [map_node_name, "Batch"]).apply(
+ optimization.optimize(["map_vectorization"]))
+
+ return unoptimized, optimized
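+
+  # For illustration only (a sketch of the intended rewrite, not the actual
+  # graph): a pipeline built as
+  #   base.map(fn).batch(100)
+  # is expected, when "map_vectorization" fires, to behave like
+  #   base.batch(100).map(vectorized_fn)
+  # where vectorized_fn is a hypothetical batched version of fn; the
+  # assert_next orderings above encode exactly this Map/Batch swap.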
+
+ @parameterized.named_parameters(
+ ("Basic", lambda x: (x, x + 1), None),
+ ("Parallel", lambda x: (x, x + 1), 12),
+ ("Gather", lambda x: array_ops.gather(x, 0), 12),
+ )
+ def testOptimization(self, map_fn, num_parallel_calls):
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
+ num_parallel_calls)
+ self.assertDatasetsEqual(unoptimized, optimized)
+
+ def testOptimizationBadMapFn(self):
+    # Test a map function that raises an error.
+ def map_fn(x):
+      # x has leading dimension 5, so this will raise an out-of-range error.
+ return array_ops.gather(x, 10)
+
+ base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
+ 5, drop_remainder=True)
+ _, optimized = self._get_test_datasets(base_dataset, map_fn)
+ nxt = optimized.make_one_shot_iterator().get_next()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(nxt)
+
+ def testOptimizationWithCapturedInputs(self):
+ # Tests that vectorization works with captured inputs
+ def map_fn(x):
+ return x + y
+
+ y = constant_op.constant(1, shape=(2,))
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ # TODO(rachelim): when this optimization works, turn on expect_optimized
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self.assertDatasetsEqual(optimized, unoptimized)
+
+ def testOptimizationIgnoreStateful(self):
+
+ def map_fn(x):
+ with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+ return array_ops.identity(x)
+
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self.assertDatasetsRaiseSameError(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
+
+ def testOptimizationIgnoreRagged(self):
+ # Make sure we ignore inputs that might not be uniformly sized
+ def map_fn(x):
+ return array_ops.gather(x, 0)
+
+ # output_shape = (?,)
+ base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self.assertDatasetsEqual(unoptimized, optimized)
+
+ def testOptimizationIgnoreRaggedMap(self):
+    # Don't optimize when the output shapes of the map fn are unknown.
+ def map_fn(x):
+ return array_ops.tile(x, x)
+
+ base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self.assertDatasetsRaiseSameError(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
+
+
+class MapVectorizationBenchmark(test.Benchmark):
+ # TODO(rachelim): Add a benchmark for more expensive transformations, such as
+ # vgg_preprocessing.
+
+ def _run(self, x, num_iters=100, name=None):
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ # Warm up session...
+ sess.run(x)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(x)
+ end = time.time()
+ deltas.append(end - start)
+ median_time = np.median(deltas)
+ self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
+ return median_time
+
+ def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
+ num_elems = np.prod(input_size)
+ name_template = "{}__batch_size_{}_input_size_{}_{}"
+ unoptimized = input_dataset.map(map_fn).batch(batch_size)
+ unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
+
+ optimized = unoptimized.apply(optimization.optimize(["map_vectorization"]))
+ optimized_op = optimized.make_one_shot_iterator().get_next()
+
+ unoptimized_time = self._run(
+ unoptimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "unoptimized"))
+ optimized_time = self._run(
+ optimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "optimized"))
+
+ print("Batch size: {}\n"
+ "Input size: {}\n"
+ "Transformation: {}\n"
+ "Speedup: {}\n".format(batch_size, input_size, str_id,
+ (unoptimized_time / optimized_time)))
+
+  # Benchmarks for known-cheap map functions.
+ def benchmarkIdentity(self):
+ self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
+ "identity")
+
+ def benchmarkAddConst(self):
+ self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+
+ def benchmarkSelect(self):
+ self._benchmark_helper(lambda *args: args[0], "select")
+
+ def benchmarkCast(self):
+ self._benchmark_helper(
+ lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
+
+ def _benchmark_helper(self, map_fn, str_id):
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ batch_size = 1000
+ for input_size in input_sizes:
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+ self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
new file mode 100644
index 0000000000..a9f2ce8c03
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
@@ -0,0 +1,183 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ModelDatasetTest(test_base.DatasetTestBase):
+
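+  # These tests wrap the pipeline in optimization.model(), which is intended
+  # to model input pipeline performance and tune runtime parameters (e.g. the
+  # parallelism requested via optimization.AUTOTUNE below). They print latency
+  # statistics rather than asserting on specific values.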
+ def testModelMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelParallelMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(
+ math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelMapAndBatch(self):
+ batch_size = 16
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=optimization.AUTOTUNE,
+ batch_size=batch_size))
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(10):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelParallelInterleave(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset,
+ cycle_length=10,
+ num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelNested(self):
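+    # Nests autotuned maps inside two levels of interleave, so model() has to
+    # manage tunable parameters at several nested stages. As above, the test
+    # only reports timing statistics.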
+ k = 1024 * 1024
+ a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
+ b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
+ c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
+ dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
+
+ def f1(a, b, c):
+ x, y = a
+ return math_ops.matmul(x, y), b, c
+
+ def f2(a, b, c):
+ x, y = b
+ return a, math_ops.matmul(x, y), c
+
+ def f3(a, b, c):
+ x, y = c
+ return a, b, math_ops.matmul(x, y)
+
+ dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=2)
+
+ dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=2)
+
+ dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
new file mode 100644
index 0000000000..092e0ff62a
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class NoopEliminationTest(test_base.DatasetTestBase):
+
+ def testNoopElimination(self):
+ a = constant_op.constant(1, dtype=dtypes.int64)
+ b = constant_op.constant(2, dtype=dtypes.int64)
+ some_tensor = math_ops.mul(a, b)
+
+ dataset = dataset_ops.Dataset.range(5)
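+    # Of the transformations chained below, take(-1), skip(0), and repeat(1)
+    # are no-ops that noop_elimination should remove. repeat(some_tensor)
+    # (whose count is a computed tensor rather than a literal 1), skip(5),
+    # and the two prefetch(0) calls are expected to survive, matching the
+    # assert_next list.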
+ dataset = dataset.apply(
+ optimization.assert_next(
+ ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
+ dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
+ 0).repeat(1).prefetch(0)
+ dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ self.assertAllEqual(result, x)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
new file mode 100644
index 0000000000..eb661796c0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetTest(test_base.DatasetTestBase):
+
+ def testOptimizationDefault(self):
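+    # With no argument, optimize() applies only the default set of graph
+    # rewrites; map_and_batch_fusion is evidently not among them, since the
+    # Map and Batch nodes are asserted to remain separate.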
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize())
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationEmpty(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize([]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationFusion(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationStatefulFunction(self):
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda _: random_ops.random_uniform([])).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensor(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
+ dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensorSlices(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py
new file mode 100644
index 0000000000..13f924b656
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py
@@ -0,0 +1,850 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.parsing_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import numpy as np
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsing_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+# Helpers for creating Example objects
+example = example_pb2.Example
+feature = feature_pb2.Feature
+features = lambda d: feature_pb2.Features(feature=d)
+bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v))
+int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v))
+float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v))
+# Helpers for creating SequenceExample objects
+feature_list = lambda l: feature_pb2.FeatureList(feature=l)
+feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
+sequence_example = example_pb2.SequenceExample
+
+
+def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
+ flat_output):
+ tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
+
+ i = 0 # Index into the flattened output of session.run()
+ for k, v in sorted(dict_tensors.items()):
+    # TODO(shivaniagrawal): flat_output is the same as v.
+ expected_v = expected_tensors[k]
+ tf_logging.info("Comparing key: %s", k)
+ print("i", i, "flat_output", flat_output[i], "expected_v", expected_v)
+ if sparse_tensor.is_sparse(v):
+      # Three outputs for SparseTensor: indices, values, shape.
+ tester.assertEqual([k, len(expected_v)], [k, 3])
+ print("i", i, "flat_output", flat_output[i].indices, "expected_v",
+ expected_v[0])
+ tester.assertAllEqual(expected_v[0], flat_output[i].indices)
+ tester.assertAllEqual(expected_v[1], flat_output[i].values)
+ tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape)
+ else:
+ # One output for standard Tensor.
+ tester.assertAllEqual(expected_v, flat_output[i])
+ i += 1
+
+
+class ParseExampleTest(test_base.DatasetTestBase):
+
+ def _test(self,
+ input_tensor,
+ feature_val,
+ expected_values=None,
+ expected_err=None):
+
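+    # Builds a one-element dataset from `input_tensor`, applies
+    # parse_example_dataset(feature_val), and either compares the parsed
+    # output against `expected_values` or checks that `expected_err` is
+    # raised.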
+ with self.cached_session() as sess:
+ if expected_err:
+ with self.assertRaisesWithPredicateMatch(expected_err[0],
+ expected_err[1]):
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ sess.run(get_next)
+ return
+ else:
+ # Returns dict w/ Tensors and SparseTensors.
+ # Check values.
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ result = sess.run(get_next)
+ flattened = nest.flatten(result)
+ print("result", result, "expected_values", expected_values)
+ _compare_output_to_expected(self, result, expected_values, flattened)
+
+      # Check shapes; if serialized is a Tensor we need its size in order to
+      # check them properly.
+ batch_size = (
+ input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else
+ np.asarray(input_tensor).size)
+ for k, f in feature_val.items():
+ print("output_shapes as list ",
+ tuple(dataset.output_shapes[k].as_list()))
+ if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
+ self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size)
+ elif isinstance(f, parsing_ops.VarLenFeature):
+ self.assertEqual(dataset.output_shapes[k].as_list()[1], None)
+
+ def testEmptySerializedWithAllDefaults(self):
+ sparse_name = "st_a"
+ a_name = "a"
+ b_name = "b"
+ c_name = "c:has_a_tricky_name"
+ a_default = [0, 42, 0]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ c_default = np.random.rand(2).astype(np.float32)
+
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+
+ expected_output = {
+ sparse_name: expected_st_a,
+ a_name: np.array(2 * [[a_default]]),
+ b_name: np.array(2 * [b_default]),
+ c_name: np.array(2 * [c_default]),
+ }
+
+ self._test(
+ ops.convert_to_tensor(["", ""]), {
+ sparse_name:
+ parsing_ops.VarLenFeature(dtypes.int64),
+ a_name:
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ b_name:
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ c_name:
+ parsing_ops.FixedLenFeature(
+ (2,), dtypes.float32, default_value=c_default),
+ },
+ expected_values=expected_output)
+
+ def testEmptySerializedWithoutDefaultsShouldFail(self):
+ input_features = {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=[0, 42, 0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3),
+ dtypes.string,
+ default_value=np.random.rand(3, 3).astype(bytes)),
+ # Feature "c" is missing a default, this gap will cause failure.
+ "c":
+ parsing_ops.FixedLenFeature(
+ (2,), dtype=dtypes.float32),
+ }
+
+ # Edge case where the key is there but the feature value is empty
+ original = example(features=features({"c": feature()}))
+ self._test(
+ [original.SerializeToString()],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ # Standard case of missing key and value.
+ self._test(
+ ["", ""],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ def testDenseNotMatchingShapeShouldFail(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1, 3]),
+ })), example(features=features({
+ "a": float_feature([-1, -1]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)},
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Key: a, Index: 1. Number of float values"))
+
+ def testDenseDefaultNoShapeShouldFail(self):
+ original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
+ expected_err=(ValueError, "Missing shape for feature a"))
+
+ def testSerializedContainingSparse(self):
+ original = [
+ example(features=features({
+ "st_c": float_feature([3, 4])
+ })),
+ example(features=features({
+ "st_c": float_feature([]), # empty float list
+ })),
+ example(features=features({
+ "st_d": feature(), # feature with nothing in it
+ })),
+ example(features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_st_c = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
+ [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
+                [4, 3], dtype=np.int64)) # batch == 4, max_elems = 3
+
+ expected_st_d = ( # indices, values, shape
+ np.array(
+ [[3, 0]], dtype=np.int64), np.array(
+ ["hi"], dtype=bytes), np.array(
+                [4, 1], dtype=np.int64)) # batch == 4, max_elems = 1
+
+ expected_output = {
+ "st_c": expected_st_c,
+ "st_d": expected_st_d,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "st_c": parsing_ops.VarLenFeature(dtypes.float32),
+ "st_d": parsing_ops.VarLenFeature(dtypes.string)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeature(self):
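+    # SparseFeature(["idx"], "val", dtypes.float32, [13]) assembles a
+    # SparseTensor whose indices come from the "idx" feature, whose values
+    # come from the "val" feature, and whose dense size is 13.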
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx":
+ int64_feature([0, 9, 3]) # unsorted
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
+ np.array(
+ [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
+ [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])},
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeatureReuse(self):
+ original = [
+ example(features=features({
+ "val1": float_feature([3, 4]),
+ "val2": float_feature([5, 6]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val1": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp1 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [3.0, 4.0], dtype=np.float32), np.array(
+ [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ expected_sp2 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [5.0, 6.0], dtype=np.float32), np.array(
+                    [2, 7], dtype=np.int64)) # batch == 2, max_elems = 7
+
+ expected_output = {
+ "sp1": expected_sp1,
+ "sp2": expected_sp2,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp1":
+ parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
+ "sp2":
+ parsing_ops.SparseFeature(
+ "idx", "val2", dtypes.float32, size=7, already_sorted=True)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContaining3DSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx0": int64_feature([5, 10]),
+ "idx1": int64_feature([0, 2]),
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx0": int64_feature([]),
+ "idx1": int64_feature([]),
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx0": int64_feature([0, 9, 3]), # unsorted
+ "idx1": int64_feature([1, 0, 2]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = (
+ # indices
+ np.array(
+ [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
+ dtype=np.int64),
+ # values
+ np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
+ # shape batch == 4, max_elems = 13
+ np.array([4, 13, 3], dtype=np.int64))
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp":
+ parsing_ops.SparseFeature(["idx0", "idx1"], "val",
+ dtypes.float32, [13, 3])
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDense(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ original = [
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ })), example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b""]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+  # This test is identical to the previous one except
+  # for the creation of 'serialized'.
+ def testSerializedContainingDenseWithConcat(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+    # TODO(lew): Feature appearing twice should be an error in the future.
+ original = [
+ (example(features=features({
+ aname: float_feature([10, 10]),
+ })), example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ }))),
+ (
+ example(features=features({
+ bname: bytes_feature([b"b100"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"]),
+ })),),
+ ]
+
+ serialized = [
+ m.SerializeToString() + n.SerializeToString() for (m, n) in original
+ ]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseScalar(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1]),
+ })), example(features=features({}))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1,), dtype=dtypes.float32, default_value=-1),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseWithDefaults(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1]),
+ })),
+ example(features=features({
+ "b": bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ "b": feature()
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
+ 1),
+ "b":
+ np.array(
+ ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
+ 1),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "c"], dtype="|S"), np.array(
+                    [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ original = [
+ example(features=features({
+ "c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "c": float_feature([1, 2]),
+ "val": bytes_feature([b"c"]),
+ "idx": int64_feature([7])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ a_default = [1, 2, 3]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ expected_output = {
+ "st_a": expected_st_a,
+ "sp": expected_sp,
+ "a": np.array(2 * [[a_default]]),
+ "b": np.array(2 * [b_default]),
+ "c": np.array(
+ [[3, 4], [1, 2]], dtype=np.float32),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ # Feature "c" must be provided, since it has no default_value.
+ "c":
+ parsing_ops.FixedLenFeature((2,), dtypes.float32),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
+ expected_idx = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
+ np.array([0, 3, 7, 1]), np.array(
+            [2, 2], dtype=np.int64)) # batch == 2, max_elems = 2
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "d", "c"], dtype="|S"), np.array(
+                    [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ original = [
+ example(features=features({
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "val": bytes_feature([b"c", b"d"]),
+ "idx": int64_feature([7, 1])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "idx": expected_idx,
+ "sp": expected_sp,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "idx":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
+ },
+ expected_values=expected_output)
+
+ def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
+ # During parsing, data read from the serialized proto is stored in buffers.
+ # For small batch sizes, a buffer will contain one minibatch entry.
+ # For larger batch sizes, a buffer may contain several minibatch
+ # entries. This test identified a bug where the code that copied
+ # data out of the buffers and into the output tensors assumed each
+ # buffer only contained one minibatch entry. The bug has since been fixed.
+ truth_int = [i for i in range(batch_size)]
+ truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
+ for i in range(batch_size)]
+
+ expected_str = copy.deepcopy(truth_str)
+
+ # Delete some intermediate entries
+ for i in range(batch_size):
+ col = 1
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry
+ expected_str[i][col] = b"default"
+ col -= 1
+ truth_str[i].pop()
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry (possibly again)
+ expected_str[i][col] = b"default"
+ truth_str[i].pop()
+
+ expected_output = {
+ # Batch size batch_size, 1 time step.
+ "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1),
+ # Batch size batch_size, 2 time steps.
+ "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2),
+ }
+
+ original = [
+ example(features=features(
+ {"a": int64_feature([truth_int[i]]),
+ "b": bytes_feature(truth_str[i])}))
+ for i in range(batch_size)
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized, dtype=dtypes.string), {
+ "a":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=(),
+ dtype=dtypes.int64,
+ allow_missing=True,
+ default_value=-1),
+ "b":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.string,
+ allow_missing=True,
+ default_value="default"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingVarLenDenseLargerBatch(self):
+ np.random.seed(3456)
+ for batch_size in (1, 10, 20, 100, 256):
+ self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
+
+ def testSerializedContainingVarLenDense(self):
+ aname = "a"
+ bname = "b"
+ cname = "c"
+ dname = "d"
+ original = [
+ example(features=features({
+ cname: int64_feature([2]),
+ })),
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str", b"b1_str"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1, 2, 2]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ aname: float_feature([]),
+ cname: int64_feature([3]),
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [
+ [0, 0, 0, 0],
+ [1, 1, 0, 0],
+ [-1, -1, 2, 2],
+ [0, 0, 0, 0],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1),
+ bname:
+ np.array(
+ [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
+ dtype=bytes).reshape(4, 2, 1, 1, 1),
+ cname:
+ np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
+ dname:
+ np.empty(shape=(4, 0), dtype=bytes),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_values=expected_output)
+
+ # Test with padding values.
+ expected_output_custom_padding = dict(expected_output)
+ expected_output_custom_padding[aname] = np.array(
+ [
+ [-2, -2, -2, -2],
+ [1, 1, -2, -2],
+ [-1, -1, 2, 2],
+ [-2, -2, -2, -2],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1)
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=-2.0),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ }, expected_output_custom_padding)
+
+ # Change number of required values so the inputs are not a
+ # multiple of this size.
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(
+ errors_impl.OpError, "Key: b, Index: 2. "
+ "Number of bytes values is not a multiple of stride length."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=[]),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Cannot reshape a tensor with 0 elements to shape"))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "First dimension of shape for feature a unknown. "
+ "Consider using FixedLenSequenceFeature."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ cname:
+ parsing_ops.FixedLenFeature(
+ (1, None), dtype=dtypes.int64, default_value=[[1]]),
+ },
+ expected_err=(ValueError,
+ "All dimensions of shape for feature c need to be known "
+ r"but received \(1, None\)."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=False),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Unsupported: FixedLenSequenceFeature requires "
+ "allow_missing to be True."))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py
new file mode 100644
index 0000000000..7d7b842c17
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py
@@ -0,0 +1,948 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import compat
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+
+
+class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ self._event = threading.Event()
+
+ def _create_ds_and_iterator(self, device0, initializable=False):
+
+ def gen():
+ for i in range(1, 10):
+ yield [float(i)]
+ if i == 6:
+ self._event.set()
+
+ with ops.device(device0):
+ ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
+ if initializable:
+ ds_iterator = ds.make_initializable_iterator()
+ else:
+ ds_iterator = ds.make_one_shot_iterator()
+ return (ds, ds_iterator)
+
+ def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
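+    # Descriptive sketch of the setup: the buffering resource created below
+    # lives on device1 and repeatedly invokes _remote_fn on device0 (the
+    # iterator's device), keeping up to buffer_size prefetched elements
+    # available on device1.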
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.float32],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name=buffer_name)
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.float32])
+ reset_op = prefetching_ops.function_buffering_resource_reset(
+ function_buffer_resource=buffer_resource_handle)
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ return (prefetch_op, reset_op, destroy_op)
+
+ def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
+ prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
+ device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ sess.run(destroy_op)
+
+ def testSameDeviceCPU(self):
+ self._prefetch_fn_helper_one_shot("same_device_cpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/cpu:0")
+
+ def testDifferentDeviceCPU(self):
+ self._prefetch_fn_helper_one_shot("diff_device_cpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/cpu:1")
+
+ def testDifferentDeviceCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ self._prefetch_fn_helper_one_shot("cpu_gpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/gpu:0")
+
+ def testReinitialization(self):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/cpu:1"
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+ prefetch_op, reset_op, destroy_op = self._create_ops(
+ ds, ds_iterator, "reinit", device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ sess.run(ds_iterator.initializer)
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+      # Let's reset the function buffering resource and reinitialize the
+      # iterator. We should be able to go through the sequence again.
+ self._event.clear()
+ sess.run(reset_op)
+ sess.run(ds_iterator.initializer)
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ sess.run(destroy_op)
+
+ def testReinitializationOutOfRange(self):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/cpu:1"
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+ prefetch_op, reset_op, destroy_op = self._create_ops(
+ ds, ds_iterator, "reinit", device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ sess.run(ds_iterator.initializer)
+ for i in range(1, 10):
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [float(i)])
+      # Fetch twice after the sequence ends to exercise end-of-sequence errors.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ # Now reset everything and try it out again.
+ self._event.clear()
+ sess.run(reset_op)
+ sess.run(ds_iterator.initializer)
+ for i in range(1, 10):
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [float(i)])
+      # Fetch twice after the sequence ends to exercise end-of-sequence errors.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
+ def testStringsGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/gpu:0"
+
+ ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
+ ds_iterator = ds.make_one_shot_iterator()
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.string],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name="strings")
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.string])
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ with self.cached_session() as sess:
+ self.assertEqual([b"a"], sess.run(prefetch_op))
+ self.assertEqual([b"b"], sess.run(prefetch_op))
+ self.assertEqual([b"c"], sess.run(prefetch_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
+
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
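+ """Tests for the `prefetching_ops.prefetch_to_device()` transformation."""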
+
+ def testPrefetchToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0. In this test the prefetch target is the same device, so the block
+ # simply makes the iterator placement explicit.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchSparseTensorsToDevice(self):
+ def make_tensor(i):
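+ # `i` is a scalar tensor here, so `i * [1]` evaluates to the one-element
+ # tensor [i], matching the single index [[0, 0]].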
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
+class CopyToDeviceTest(test_base.DatasetTestBase):
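+ """Tests for the `prefetching_ops.copy_to_device()` transformation."""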
+
+ def testCopyToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceInt32(self):
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int32, next_element.dtype)
+ self.assertEqual((4,), next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDevice(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDeviceWithPrefetch(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
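+ # NOTE: int32 is a dtype that TensorFlow often pins to host memory on GPU
+ # devices, which makes it a useful case for `copy_to_device()`.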
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32AndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStrings(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStringsAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDevicePingPongCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ with compat.forward_compatibility_horizon(2018, 8, 4):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0"))
+ back_to_cpu_dataset = device_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = back_to_cpu_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInitAndPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInitAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testIteratorGetNextAsOptionalOnGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(3)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_elem = iterator_ops.get_next_as_optional(iterator)
+ elem_has_value_t = next_elem.has_value()
+ elem_value_t = next_elem.get_value()
+
+ with self.cached_session() as sess:
+ # Before initializing the iterator, evaluating the optional fails with
+ # a FailedPreconditionError.
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(elem_has_value_t)
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(elem_value_t)
+
+ # For each element of the dataset, assert that the optional evaluates to
+ # the expected value.
+ sess.run(iterator.initializer)
+ for i in range(3):
+ elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
+ self.assertTrue(elem_has_value)
+ self.assertEqual(i, elem_value)
+
+ # After exhausting the iterator, `next_elem.has_value()` will evaluate to
+ # false, and attempting to get the value will fail.
+ for _ in range(2):
+ self.assertFalse(sess.run(elem_has_value_t))
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(elem_value_t)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py
new file mode 100644
index 0000000000..22412c3965
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py
@@ -0,0 +1,78 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test RangeDataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import counter
+from tensorflow.python.data.experimental.ops import enumerate_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import test
+
+
+class RangeDatasetTest(test_base.DatasetTestBase):
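+ """Tests for the experimental `enumerate_dataset()` and `Counter` transformations."""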
+
+ def testEnumerateDataset(self):
+ components = (["a", "b"], [1, 2], [37.0, 38])
+ start = constant_op.constant(20, dtype=dtypes.int64)
+
+ iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply(
+ enumerate_ops.enumerate_dataset(start)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual(dtypes.int64, get_next[0].dtype)
+ self.assertEqual((), get_next[0].shape)
+ self.assertEqual([tensor_shape.TensorShape([])] * 3,
+ [t.shape for t in get_next[1]])
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
+ self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testCounter(self):
+ """Test dataset construction using `count`."""
+ iterator = (counter.Counter(start=3, step=4)
+ .make_one_shot_iterator())
+ get_next = iterator.get_next()
+ self.assertEqual([], get_next.shape.as_list())
+ self.assertEqual(dtypes.int64, get_next.dtype)
+
+ negative_iterator = (counter.Counter(start=0, step=-1)
+ .make_one_shot_iterator())
+ negative_get_next = negative_iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertEqual(3, sess.run(get_next))
+ self.assertEqual(3 + 4, sess.run(get_next))
+ self.assertEqual(3 + 2 * 4, sess.run(get_next))
+
+ self.assertEqual(0, sess.run(negative_get_next))
+ self.assertEqual(-1, sess.run(negative_get_next))
+ self.assertEqual(-2, sess.run(negative_get_next))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py
new file mode 100644
index 0000000000..a02f4bd14f
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py
@@ -0,0 +1,1083 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class ReadBatchFeaturesTest(
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
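+ """Tests for batched feature reading via the test base's
+ `make_batch_feature()` helper.
+ """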
+
+ def testRead(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from file 0.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ 0,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from file 1.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames[1],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ 1,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, num_epochs=num_epochs)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
+ def testReadWithEquivalentDataset(self):
+ features = {
+ "file": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "record": parsing_ops.FixedLenFeature([], dtypes.int64),
+ }
+ dataset = (
+ core_readers.TFRecordDataset(self.test_filenames)
+ .map(lambda x: parsing_ops.parse_single_example(x, features))
+ .repeat(10).batch(2))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
+ range(self._num_files), 2, 10):
+ actual_batch = sess.run(next_element)
+ self.assertAllEqual(file_batch, actual_batch["file"])
+ self.assertAllEqual(record_batch, actual_batch["record"])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testReadWithFusedShuffleRepeatDataset(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ for batch_size in [1, 2]:
+ # Test that shuffling with the same seed produces the same result.
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs1 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ self.assertAllEqual(batch1[i], batch2[i])
+
+ # Test that shuffling with different seeds produces a different order.
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs1 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=15).make_one_shot_iterator().get_next()
+ all_equal = True
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
+ self.assertFalse(all_equal)
+
+ def testParallelReadersAndParsers(self):
+ num_epochs = 5
+ for batch_size in [1, 2]:
+ for reader_num_threads in [2, 4]:
+ for parser_num_threads in [2, 4]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default():
+ # Basic test: read from file 0.
+ outputs = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ drop_final_batch=True).make_one_shot_iterator().get_next()
+ for tensor in nest.flatten(outputs):
+ if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
+ self.assertEqual(tensor.shape[0], batch_size)
+
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=None,
+ batch_size=32)
+ for shape, clazz in zip(nest.flatten(dataset.output_shapes),
+ nest.flatten(dataset.output_classes)):
+ if issubclass(clazz, ops.Tensor):
+ self.assertEqual(32, shape[0])
+
+
+class MakeCsvDatasetTest(test_base.DatasetTestBase):
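+ """Tests for `readers.make_csv_dataset()`."""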
+
+ def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
+ return readers.make_csv_dataset(
+ filenames, batch_size=batch_size, num_epochs=num_epochs, **kwargs)
+
+ def _setup_files(self, inputs, linebreak="\n", compression_type=None):
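+ """Writes each list of lines in `inputs` to its own temporary CSV file,
+ optionally GZIP- or ZLIB-compressed, and returns the filenames.
+ """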
+ filenames = []
+ for i, ip in enumerate(inputs):
+ fn = os.path.join(self.get_temp_dir(), "temp_%d.csv" % i)
+ contents = linebreak.join(ip).encode("utf-8")
+ if compression_type is None:
+ with open(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "GZIP":
+ with gzip.GzipFile(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "ZLIB":
+ contents = zlib.compress(contents)
+ with open(fn, "wb") as f:
+ f.write(contents)
+ else:
+ raise ValueError("Unsupported compression_type", compression_type)
+ filenames.append(fn)
+ return filenames
+
+ def _next_expected_batch(self, expected_output, expected_keys, batch_size,
+ num_epochs):
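+ """Yields the batched feature dicts that the dataset is expected to
+ produce, including any smaller final batch.
+ """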
+ features = {k: [] for k in expected_keys}
+ for _ in range(num_epochs):
+ for values in expected_output:
+ for n, key in enumerate(expected_keys):
+ features[key].append(values[n])
+ if len(features[expected_keys[0]]) == batch_size:
+ yield features
+ features = {k: [] for k in expected_keys}
+ if features[expected_keys[0]]: # Leftover from the last batch
+ yield features
+
+ def _verify_output(
+ self,
+ sess,
+ dataset,
+ batch_size,
+ num_epochs,
+ label_name,
+ expected_output,
+ expected_keys,
+ ):
+ nxt = dataset.make_one_shot_iterator().get_next()
+
+ for expected_features in self._next_expected_batch(
+ expected_output,
+ expected_keys,
+ batch_size,
+ num_epochs,
+ ):
+ actual_features = sess.run(nxt)
+
+ if label_name is not None:
+ expected_labels = expected_features.pop(label_name)
+ self.assertAllEqual(expected_labels, actual_features[1])
+ actual_features = actual_features[0]
+
+ for k in expected_features.keys():
+ # Compare features
+ self.assertAllEqual(expected_features[k], actual_features[k])
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(nxt)
+
+ def _test_dataset(self,
+ inputs,
+ expected_output,
+ expected_keys,
+ batch_size=1,
+ num_epochs=1,
+ label_name=None,
+ **kwargs):
+ """Checks that elements produced by CsvDataset match expected output."""
+ # Note: TensorFlow returns string features as byte strings under Python 3,
+ # so expected string values in these tests are written as b"..." literals.
+ filenames = self._setup_files(
+ inputs, compression_type=kwargs.get("compression_type", None))
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ dataset = self._make_csv_dataset(
+ filenames,
+ batch_size=batch_size,
+ num_epochs=num_epochs,
+ label_name=label_name,
+ **kwargs)
+ self._verify_output(sess, dataset, batch_size, num_epochs, label_name,
+ expected_output, expected_keys)
+
+ def testMakeCSVDataset(self):
+ """Tests making a CSV dataset with keys and defaults provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withBatchSizeAndEpochs(self):
+ """Tests making a CSV dataset with keys and defaults provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=3,
+ num_epochs=10,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withCompressionType(self):
+ """Tests `compression_type` argument."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ for compression_type in ("GZIP", "ZLIB"):
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ compression_type=compression_type,
+ )
+
+ def testMakeCSVDataset_withBadInputs(self):
+ """Tests that exception is raised when input is malformed.
+ """
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
+
+ # Duplicate column names
+ with self.assertRaises(ValueError):
+ self._make_csv_dataset(
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ label_name="col0",
+ column_names=column_names * 2)
+
+ # Label key not one of column names
+ with self.assertRaises(ValueError):
+ self._make_csv_dataset(
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ label_name="not_a_real_label",
+ column_names=column_names)
+
+ def testMakeCSVDataset_withNoLabel(self):
+ """Tests making a CSV dataset with no label provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withNoHeader(self):
+ """Tests that datasets can be created from CSV files with no header line.
+ """
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [["0,1,2,3,4", "5,6,7,8,9"], ["10,11,12,13,14", "15,16,17,18,19"]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=False,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withTypes(self):
+ """Tests that defaults can be a dtype instead of a Tensor for required vals.
+ """
+ record_defaults = [
+ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
+ dtypes.string
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x[0] for x in column_names), "0,1,2,3,4", "5,6,7,8,9"],
+ [
+ ",".join(x[0] for x in column_names), "10,11,12,13,14",
+ "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withNoColNames(self):
+ """Tests that datasets can be created when column names are not specified.
+
+ In that case, we should infer the column names from the header lines.
+ """
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withTypeInferenceMismatch(self):
+ # Test that an error is raised when the number of column names provided
+ # doesn't match the number of fields in the file.
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
+ with self.assertRaises(ValueError):
+ self._make_csv_dataset(
+ filenames,
+ column_names=column_names + ["extra_name"],
+ column_defaults=None,
+ batch_size=2,
+ num_epochs=10)
+
+ def testMakeCSVDataset_withTypeInference(self):
+ """Tests that datasets can be created when no defaults are specified.
+
+ In that case, we should infer the types from the first N records.
+ """
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
+ ]]
+ expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ )
+
+ def testMakeCSVDataset_withTypeInferenceFallthrough(self):
+ """Tests that datasets can be created when no defaults are specified.
+
+ Tests on a deliberately tricky file.
+ """
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ ",,,,",
+ "0,0,0.0,0.0,0.0",
+ "0,%s,2.0,3e50,rabbit" % str_int32_max,
+ ",,,,",
+ ]]
+ expected_output = [[0, 0, 0, 0, b""], [0, 0, 0, 0, b"0.0"],
+ [0, 2**33, 2.0, 3e50, b"rabbit"], [0, 0, 0, 0, b""]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ )
+
+ def testMakeCSVDataset_withSelectCols(self):
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
+ ]]
+ expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
+
+ select_cols = [1, 3, 4]
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ column_defaults=[record_defaults[i] for i in select_cols],
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can still do inference without provided defaults
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can still do column name inference
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can specify column names instead of indices
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=[column_names[i] for i in select_cols],
+ )
+
+ def testMakeCSVDataset_withSelectColsError(self):
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
+ ]]
+
+ select_cols = [1, 3, 4]
+ filenames = self._setup_files(inputs)
+
+ with self.assertRaises(ValueError):
+ # Mismatch in number of defaults and number of columns selected,
+ # should raise an error
+ self._make_csv_dataset(
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ select_columns=select_cols)
+
+ with self.assertRaises(ValueError):
+ # Invalid column name should raise an error
+ self._make_csv_dataset(
+ filenames,
+ batch_size=1,
+ column_defaults=[[0]],
+ column_names=column_names,
+ label_name=None,
+ select_columns=["invalid_col_name"])
+
+ def testMakeCSVDataset_withShuffle(self):
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ def str_series(st):
+ return ",".join(str(i) for i in range(st, st + 5))
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [
+ [",".join(x for x in column_names)
+ ] + [str_series(5 * i) for i in range(15)],
+ [",".join(x for x in column_names)] +
+ [str_series(5 * i) for i in range(15, 20)],
+ ]
+
+ filenames = self._setup_files(inputs)
+
+ total_records = 20
+ for batch_size in [1, 2]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Test that shuffling with the same seed produces the same result
+ dataset1 = self._make_csv_dataset(
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ batch_size=batch_size,
+ header=True,
+ shuffle=True,
+ shuffle_seed=5,
+ num_epochs=2,
+ )
+ dataset2 = self._make_csv_dataset(
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ batch_size=batch_size,
+ header=True,
+ shuffle=True,
+ shuffle_seed=5,
+ num_epochs=2,
+ )
+ outputs1 = dataset1.make_one_shot_iterator().get_next()
+ outputs2 = dataset2.make_one_shot_iterator().get_next()
+ for _ in range(total_records // batch_size):
+ batch1 = nest.flatten(sess.run(outputs1))
+ batch2 = nest.flatten(sess.run(outputs2))
+ for i in range(len(batch1)):
+ self.assertAllEqual(batch1[i], batch2[i])
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Test that shuffling with a different seed produces different results
+ dataset1 = self._make_csv_dataset(
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ batch_size=batch_size,
+ header=True,
+ shuffle=True,
+ shuffle_seed=5,
+ num_epochs=2,
+ )
+ dataset2 = self._make_csv_dataset(
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ batch_size=batch_size,
+ header=True,
+ shuffle=True,
+ shuffle_seed=6,
+ num_epochs=2,
+ )
+ outputs1 = dataset1.make_one_shot_iterator().get_next()
+ outputs2 = dataset2.make_one_shot_iterator().get_next()
+          all_equal = True
+ for _ in range(total_records // batch_size):
+ batch1 = nest.flatten(sess.run(outputs1))
+ batch2 = nest.flatten(sess.run(outputs2))
+ for i in range(len(batch1)):
+ all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
+ self.assertFalse(all_equal)
+
+ def testIndefiniteRepeatShapeInference(self):
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
+ dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None)
+ for shape in nest.flatten(dataset.output_shapes):
+ self.assertEqual(32, shape[0])
+
+
+class MakeTFRecordDatasetTest(
+ reader_dataset_ops_test_base.TFRecordDatasetTestBase):
+
+ def _interleave(self, iterators, cycle_length):
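+    """Simulates interleaved reads by cycling through `iterators`.
+
+    Yields one element at a time from up to `cycle_length` open iterators,
+    replacing each exhausted iterator with a pending one until all iterators
+    are exhausted.
+    """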
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length,
+ drop_final_batch,
+ use_parser_fn):
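+    """Generates the batches expected when reading the given files."""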
+
+ def _next_record(file_indices):
+ for j in file_indices:
+ for i in range(self._num_records):
+ yield j, i
+
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
+ record_batch = []
+ batch_index = 0
+ for _ in range(num_epochs):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for f, r in next_records:
+ record = self._record(f, r)
+ if use_parser_fn:
+ record = record[1:]
+ record_batch.append(record)
+ batch_index += 1
+ if len(record_batch) == batch_size:
+ yield record_batch
+ record_batch = []
+ batch_index = 0
+ if record_batch and not drop_final_batch:
+ yield record_batch
+
+ def _verify_records(self,
+ sess,
+ outputs,
+ batch_size,
+ file_index,
+ num_epochs,
+ interleave_cycle_length,
+ drop_final_batch,
+ use_parser_fn):
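+    """Checks the dataset output in `outputs` against the expected batches."""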
+ if file_index is not None:
+ file_indices = [file_index]
+ else:
+ file_indices = range(self._num_files)
+
+ for expected_batch in self._next_expected_batch(
+ file_indices, batch_size, num_epochs, interleave_cycle_length,
+ drop_final_batch, use_parser_fn):
+ actual_batch = sess.run(outputs)
+ self.assertAllEqual(expected_batch, actual_batch)
+
+ def _read_test(self, batch_size, num_epochs, file_index=None,
+ num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
+ if file_index is None:
+ file_pattern = self.test_filenames
+ else:
+ file_pattern = self.test_filenames[file_index]
+
+ if parser_fn:
+ fn = lambda x: string_ops.substr(x, 1, 999)
+ else:
+ fn = None
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs = readers.make_tf_record_dataset(
+ file_pattern=file_pattern,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ parser_fn=fn,
+ num_parallel_reads=num_parallel_reads,
+ drop_final_batch=drop_final_batch,
+ shuffle=False).make_one_shot_iterator().get_next()
+ self._verify_records(
+ sess, outputs, batch_size, file_index, num_epochs=num_epochs,
+ interleave_cycle_length=num_parallel_reads,
+ drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(outputs)
+
+ def testRead(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ # Basic test: read from file 0.
+ self._read_test(batch_size, num_epochs, 0)
+
+ # Basic test: read from file 1.
+ self._read_test(batch_size, num_epochs, 1)
+
+ # Basic test: read from both files.
+ self._read_test(batch_size, num_epochs)
+
+ # Basic test: read from both files, with parallel reads.
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8)
+
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2, 10]:
+ for num_epochs in [1, 3]:
+ # Read from file 0.
+ self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
+
+ # Read from both files.
+ self._read_test(batch_size, num_epochs, drop_final_batch=True)
+
+ # Read from both files, with parallel reads.
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8,
+ drop_final_batch=True)
+
+ def testParserFn(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ for drop_final_batch in [False, True]:
+ self._read_test(batch_size, num_epochs, parser_fn=True,
+ drop_final_batch=drop_final_batch)
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8,
+ parser_fn=True, drop_final_batch=drop_final_batch)
+
+ def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
+ seed=None):
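+    # Runs two full passes over the shuffled dataset and checks that both
+    # passes yield the same number of batches, that they match exactly when a
+    # seed is given, and that every expected record appears in each pass.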
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ num_parallel_reads=num_parallel_reads,
+ shuffle=True,
+ shuffle_seed=seed)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ sess.run(iterator.initializer)
+ first_batches = []
+ try:
+ while True:
+ first_batches.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+
+ sess.run(iterator.initializer)
+ second_batches = []
+ try:
+ while True:
+ second_batches.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+
+ self.assertEqual(len(first_batches), len(second_batches))
+ if seed is not None:
+          # If a seed is set, both passes should produce identical batches.
+ for i in range(len(first_batches)):
+ self.assertAllEqual(first_batches[i], second_batches[i])
+
+ expected = []
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ expected.extend([self._record(f, r)] * num_epochs)
+
+ for batches in (first_batches, second_batches):
+ actual = []
+ for b in batches:
+ actual.extend(b)
+ self.assertAllEqual(sorted(expected), sorted(actual))
+
+ def testShuffle(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ for num_parallel_reads in [1, 2]:
+ # Test that all expected elements are produced
+ self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
+ # Test that elements are produced in a consistent order if
+ # you specify a seed.
+ self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
+ seed=21345)
+
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
+ for shape in nest.flatten(dataset.output_shapes):
+ self.assertEqual(32, shape[0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
new file mode 100644
index 0000000000..b6ab80d132
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
@@ -0,0 +1,353 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing reader datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.util import compat
+
+
+class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing FixedLengthRecordDataset."""
+
+ def setUp(self):
+ super(FixedLengthRecordDatasetTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self._header_bytes = 5
+ self._record_bytes = 3
+ self._footer_bytes = 2
+
+ def _record(self, f, r):
+ return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
+
+ def _createFiles(self):
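+    # Each file consists of a fake header, `_num_records` fixed-length
+    # records, and a fake footer.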
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
+ filenames.append(fn)
+ with open(fn, "wb") as f:
+ f.write(b"H" * self._header_bytes)
+ for j in range(self._num_records):
+ f.write(self._record(i, j))
+ f.write(b"F" * self._footer_bytes)
+ return filenames
+
+
+class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing `make_batched_feature_dataset`."""
+
+ def setUp(self):
+ super(ReadBatchFeaturesTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self.test_filenames = self._createFiles()
+
+ def make_batch_feature(self,
+ filenames,
+ num_epochs,
+ batch_size,
+ label_key=None,
+ reader_num_threads=1,
+ parser_num_threads=1,
+ shuffle=False,
+ shuffle_seed=None,
+ drop_final_batch=False):
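+    """Builds a batched-features dataset over `filenames` for testing."""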
+ self.filenames = filenames
+ self.num_epochs = num_epochs
+ self.batch_size = batch_size
+
+ return readers.make_batched_features_dataset(
+ file_pattern=self.filenames,
+ batch_size=self.batch_size,
+ features={
+ "file": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "record": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "keywords": parsing_ops.VarLenFeature(dtypes.string),
+ "label": parsing_ops.FixedLenFeature([], dtypes.string),
+ },
+ label_key=label_key,
+ reader=core_readers.TFRecordDataset,
+ num_epochs=self.num_epochs,
+ shuffle=shuffle,
+ shuffle_seed=shuffle_seed,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads,
+ drop_final_batch=drop_final_batch)
+
+ def _record(self, f, r, l):
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ "file":
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[f])),
+ "record":
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[r])),
+ "keywords":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=self._get_keywords(f, r))),
+ "label":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[compat.as_bytes(l)]))
+ }))
+ return example.SerializeToString()
+
+ def _get_keywords(self, f, r):
+ num_keywords = 1 + (f + r) % 2
+ keywords = []
+ for index in range(num_keywords):
+ keywords.append(compat.as_bytes("keyword%d" % index))
+ return keywords
+
+ def _sum_keywords(self, num_files):
+ sum_keywords = 0
+ for i in range(num_files):
+ for j in range(self._num_records):
+ sum_keywords += 1 + (i + j) % 2
+ return sum_keywords
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._record(i, j, "fake-label"))
+ writer.close()
+ return filenames
+
+ def _run_actual_batch(self, outputs, sess, label_key_provided=False):
+ if label_key_provided:
+ # outputs would be a tuple of (feature dict, label)
+ label_op = outputs[1]
+ features_op = outputs[0]
+ else:
+ features_op = outputs
+ label_op = features_op["label"]
+ file_op = features_op["file"]
+ keywords_indices_op = features_op["keywords"].indices
+ keywords_values_op = features_op["keywords"].values
+ keywords_dense_shape_op = features_op["keywords"].dense_shape
+ record_op = features_op["record"]
+ return sess.run([
+ file_op, keywords_indices_op, keywords_values_op,
+ keywords_dense_shape_op, record_op, label_op
+ ])
+
+ def _next_actual_batch(self, sess, label_key_provided=False):
+ return self._run_actual_batch(self.outputs, sess, label_key_provided)
+
+ def _interleave(self, iterators, cycle_length):
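+    """Simulates cycle-based interleaving of the given iterators."""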
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=1):
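+    """Generates the expected feature batches, including the sparse keywords."""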
+
+ def _next_record(file_indices):
+ for j in file_indices:
+ for i in range(self._num_records):
+ yield j, i, compat.as_bytes("fake-label")
+
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
+ file_batch = []
+ keywords_batch_indices = []
+ keywords_batch_values = []
+ keywords_batch_max_len = 0
+ record_batch = []
+ batch_index = 0
+ label_batch = []
+ for _ in range(num_epochs):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for record in next_records:
+ f = record[0]
+ r = record[1]
+ label_batch.append(record[2])
+ file_batch.append(f)
+ record_batch.append(r)
+ keywords = self._get_keywords(f, r)
+ keywords_batch_values.extend(keywords)
+ keywords_batch_indices.extend(
+ [[batch_index, i] for i in range(len(keywords))])
+ batch_index += 1
+ keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
+ if len(file_batch) == batch_size:
+ yield [
+ file_batch, keywords_batch_indices, keywords_batch_values,
+ [batch_size, keywords_batch_max_len], record_batch, label_batch
+ ]
+ file_batch = []
+ keywords_batch_indices = []
+ keywords_batch_values = []
+ keywords_batch_max_len = 0
+ record_batch = []
+ batch_index = 0
+ label_batch = []
+ if file_batch:
+ yield [
+ file_batch, keywords_batch_indices, keywords_batch_values,
+ [len(file_batch), keywords_batch_max_len], record_batch, label_batch
+ ]
+
+ def verify_records(self,
+ sess,
+ batch_size,
+ file_index=None,
+ num_epochs=1,
+ label_key_provided=False,
+ interleave_cycle_length=1):
+ if file_index is not None:
+ file_indices = [file_index]
+ else:
+ file_indices = range(self._num_files)
+
+ for expected_batch in self._next_expected_batch(
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=interleave_cycle_length):
+ actual_batch = self._next_actual_batch(
+ sess, label_key_provided=label_key_provided)
+ for i in range(len(expected_batch)):
+ self.assertAllEqual(expected_batch[i], actual_batch[i])
+
+
+class TextLineDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing TextLineDataset."""
+
+ def _lineText(self, f, l):
+ return compat.as_bytes("%d: %d" % (f, l))
+
+ def _createFiles(self,
+ num_files,
+ num_lines,
+ crlf=False,
+ compression_type=None):
+ filenames = []
+ for i in range(num_files):
+ fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
+ filenames.append(fn)
+ contents = []
+ for j in range(num_lines):
+ contents.append(self._lineText(i, j))
+        # Always include a newline after the record, except after the last
+        # record of every file other than the first, so that some files end
+        # without a trailing newline.
+ if j + 1 != num_lines or i == 0:
+ contents.append(b"\r\n" if crlf else b"\n")
+ contents = b"".join(contents)
+
+ if not compression_type:
+ with open(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "GZIP":
+ with gzip.GzipFile(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "ZLIB":
+ contents = zlib.compress(contents)
+ with open(fn, "wb") as f:
+ f.write(contents)
+ else:
+ raise ValueError("Unsupported compression_type", compression_type)
+
+ return filenames
+
+
+class TFRecordDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing TFRecordDataset."""
+
+ def setUp(self):
+ super(TFRecordDatasetTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ self.test_filenames = self._createFiles()
+
+ self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
+ self.num_epochs = array_ops.placeholder_with_default(
+ constant_op.constant(1, dtypes.int64), shape=[])
+ self.compression_type = array_ops.placeholder_with_default("", shape=[])
+ self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = core_readers.TFRecordDataset(
+ self.filenames, self.compression_type).repeat(self.num_epochs)
+ batch_dataset = repeat_dataset.batch(self.batch_size)
+
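+    # Both datasets share a single iterator structure: `init_op` initializes
+    # it from the unbatched repeated dataset and `init_batch_op` from the
+    # batched dataset.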
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ self.init_op = iterator.make_initializer(repeat_dataset)
+ self.init_batch_op = iterator.make_initializer(batch_dataset)
+ self.get_next = iterator.get_next()
+
+ def _record(self, f, r):
+ return compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._record(i, j))
+ writer.close()
+ return filenames
diff --git a/tensorflow/python/data/experimental/kernel_tests/resample_test.py b/tensorflow/python/data/experimental/kernel_tests/resample_test.py
new file mode 100644
index 0000000000..775648c943
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/resample_test.py
@@ -0,0 +1,182 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.data.experimental.ops import resampling
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+def _time_resampling(
+ test_obj, data_np, target_dist, init_dist, num_to_sample):
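+  """Returns the wall time taken to draw `num_to_sample` resampled elements."""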
+ dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()
+
+ # Reshape distribution via rejection sampling.
+ dataset = dataset.apply(
+ resampling.rejection_resample(
+ class_func=lambda x: x,
+ target_dist=target_dist,
+ initial_dist=init_dist,
+ seed=142))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with test_obj.test_session() as sess:
+ start_time = time.time()
+ for _ in xrange(num_to_sample):
+ sess.run(get_next)
+ end_time = time.time()
+
+ return end_time - start_time
+
+
+class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("InitialDistributionKnown", True),
+ ("InitialDistributionUnknown", False))
+ def testDistribution(self, initial_known):
+ classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
+ target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
+ initial_dist = [0.2] * 5 if initial_known else None
+ classes = math_ops.to_int64(classes) # needed for Windows build.
+ dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
+ 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
+
+ get_next = dataset.apply(
+ resampling.rejection_resample(
+ target_dist=target_dist,
+ initial_dist=initial_dist,
+ class_func=lambda c, _: c,
+ seed=27)).make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ returned = []
+ while len(returned) < 4000:
+ returned.append(sess.run(get_next))
+
+ returned_classes, returned_classes_and_data = zip(*returned)
+ _, returned_data = zip(*returned_classes_and_data)
+ self.assertAllEqual([compat.as_bytes(str(c))
+ for c in returned_classes], returned_data)
+ total_returned = len(returned_classes)
+ class_counts = np.array([
+ len([True for v in returned_classes if v == c])
+ for c in range(5)])
+ returned_dist = class_counts / total_returned
+ self.assertAllClose(target_dist, returned_dist, atol=1e-2)
+
+ @parameterized.named_parameters(
+ ("OnlyInitial", True),
+ ("NotInitial", False))
+ def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
+ init_dist = [0.5, 0.5]
+ target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
+ num_classes = len(init_dist)
+ # We don't need many samples to test that this works.
+ num_samples = 100
+ data_np = np.random.choice(num_classes, num_samples, p=init_dist)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
+
+ # Reshape distribution.
+ dataset = dataset.apply(
+ resampling.rejection_resample(
+ class_func=lambda x: x,
+ target_dist=target_dist,
+ initial_dist=init_dist))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ returned = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ returned.append(sess.run(get_next))
+
+ def testRandomClasses(self):
+ init_dist = [0.25, 0.25, 0.25, 0.25]
+ target_dist = [0.0, 0.0, 0.0, 1.0]
+ num_classes = len(init_dist)
+    # We don't need many samples to test a Dirac delta target distribution.
+ num_samples = 100
+ data_np = np.random.choice(num_classes, num_samples, p=init_dist)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
+
+ # Apply a random mapping that preserves the data distribution.
+ def _remap_fn(_):
+ return math_ops.cast(random_ops.random_uniform([1]) * num_classes,
+ dtypes.int32)[0]
+ dataset = dataset.map(_remap_fn)
+
+ # Reshape distribution.
+ dataset = dataset.apply(
+ resampling.rejection_resample(
+ class_func=lambda x: x,
+ target_dist=target_dist,
+ initial_dist=init_dist))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ returned = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ returned.append(sess.run(get_next))
+
+ classes, _ = zip(*returned)
+ bincount = np.bincount(
+ np.array(classes),
+ minlength=num_classes).astype(np.float32) / len(classes)
+
+ self.assertAllClose(target_dist, bincount, atol=1e-2)
+
+
+class ResampleDatasetBenchmark(test.Benchmark):
+
+ def benchmarkResamplePerformance(self):
+ init_dist = [0.25, 0.25, 0.25, 0.25]
+ target_dist = [0.0, 0.0, 0.0, 1.0]
+ num_classes = len(init_dist)
+    # We don't need many samples to test a Dirac delta target distribution.
+ num_samples = 1000
+ data_np = np.random.choice(num_classes, num_samples, p=init_dist)
+
+ resample_time = _time_resampling(
+ self, data_np, target_dist, init_dist, num_to_sample=1000)
+
+ self.report_benchmark(
+ iters=1000, wall_time=resample_time, name="benchmark_resample")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py
new file mode 100644
index 0000000000..78ec80de23
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py
@@ -0,0 +1,172 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ScanDatasetTest(test_base.DatasetTestBase):
+
+ def _counting_dataset(self, start, scan_fn):
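+    # An endless stream of zeros; the emitted values are determined entirely
+    # by `scan_fn` and the initial state `start`.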
+ return dataset_ops.Dataset.from_tensors(0).repeat().apply(
+ scan_ops.scan(start, scan_fn))
+
+ def testCount(self):
+ def make_scan_fn(step):
+ return lambda state, _: (state + step, state)
+
+ start = array_ops.placeholder(dtypes.int32, shape=[])
+ step = array_ops.placeholder(dtypes.int32, shape=[])
+ take = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = self._counting_dataset(
+ start, make_scan_fn(step)).take(take).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+
+ for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
+ (10, 2, 10), (10, -1, 10),
+ (10, -2, 10)]:
+ sess.run(iterator.initializer,
+ feed_dict={start: start_val, step: step_val, take: take_val})
+ for expected, _ in zip(
+ itertools.count(start_val, step_val), range(take_val)):
+ self.assertEqual(expected, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testFibonacci(self):
+ iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
+ scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
+ ).make_one_shot_iterator()
+
+ if context.executing_eagerly():
+ next_element = iterator.get_next
+ else:
+ get_next = iterator.get_next()
+ next_element = lambda: get_next
+
+ self.assertEqual(1, self.evaluate(next_element()))
+ self.assertEqual(1, self.evaluate(next_element()))
+ self.assertEqual(2, self.evaluate(next_element()))
+ self.assertEqual(3, self.evaluate(next_element()))
+ self.assertEqual(5, self.evaluate(next_element()))
+ self.assertEqual(8, self.evaluate(next_element()))
+
+ def testSparseCount(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def make_scan_fn(step):
+ return lambda state, _: (_sparse(state.values[0] + step), state)
+
+ start = array_ops.placeholder(dtypes.int32, shape=[])
+ step = array_ops.placeholder(dtypes.int32, shape=[])
+ take = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = self._counting_dataset(
+ _sparse(start),
+ make_scan_fn(step)).take(take).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+
+ for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
+ (10, 2, 10), (10, -1, 10),
+ (10, -2, 10)]:
+ sess.run(iterator.initializer,
+ feed_dict={start: start_val, step: step_val, take: take_val})
+ for expected, _ in zip(
+ itertools.count(start_val, step_val), range(take_val)):
+ self.assertEqual(expected, sess.run(next_element).values[0])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testChangingStateShape(self):
+ # Test the fixed-point shape invariant calculations: start with
+ # initial values with known shapes, and use a scan function that
+ # changes the size of the state on each element.
+ def _scan_fn(state, input_value):
+ # Statically known rank, but dynamic length.
+ ret_longer_vector = array_ops.concat([state[0], state[0]], 0)
+ # Statically unknown rank.
+ ret_larger_rank = array_ops.expand_dims(state[1], 0)
+ return (ret_longer_vector, ret_larger_rank), (state, input_value)
+
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
+ scan_ops.scan(([0], 1), _scan_fn))
+ self.assertEqual([None], dataset.output_shapes[0][0].as_list())
+ self.assertIs(None, dataset.output_shapes[0][1].ndims)
+ self.assertEqual([], dataset.output_shapes[1].as_list())
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(5):
+ (longer_vector_val, larger_rank_val), _ = sess.run(next_element)
+ self.assertAllEqual([0] * (2**i), longer_vector_val)
+ self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testIncorrectStateType(self):
+
+ def _scan_fn(state, _):
+ return constant_op.constant(1, dtype=dtypes.int64), state
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The element types for the new state must match the initial state."):
+ dataset.apply(
+ scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
+
+ def testIncorrectReturnType(self):
+
+ def _scan_fn(unused_state, unused_input_value):
+ return constant_op.constant(1, dtype=dtypes.int64)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The scan function must return a pair comprising the new state and the "
+ "output value."):
+ dataset.apply(
+ scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
new file mode 100644
index 0000000000..20c02a5366
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -0,0 +1,555 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "dataset_serialization_test_base",
+ srcs = [
+ "dataset_serialization_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "cache_dataset_serialization_test",
+ size = "small",
+ srcs = ["cache_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "concatenate_dataset_serialization_test",
+ size = "small",
+ srcs = ["concatenate_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "csv_dataset_serialization_test",
+ size = "small",
+ srcs = ["csv_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ ],
+)
+
+py_test(
+ name = "dataset_constructor_serialization_test",
+ size = "medium",
+ srcs = ["dataset_constructor_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_serialization_test",
+ size = "medium",
+ srcs = ["filter_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "fixed_length_record_dataset_serialization_test",
+ size = "medium",
+ srcs = ["fixed_length_record_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "flat_map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["flat_map_dataset_serialization_test.py"],
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "group_by_reducer_serialization_test",
+ size = "medium",
+ srcs = ["group_by_reducer_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "group_by_window_serialization_test",
+ size = "medium",
+ srcs = ["group_by_window_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "ignore_errors_serialization_test",
+ size = "small",
+ srcs = ["ignore_errors_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "interleave_dataset_serialization_test",
+ size = "medium",
+ srcs = ["interleave_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_and_batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["map_and_batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["map_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "optimize_dataset_serialization_test",
+ size = "small",
+ srcs = ["optimize_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "padded_batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["padded_batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "parallel_interleave_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parallel_interleave_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "parallel_map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parallel_map_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "parse_example_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parse_example_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
+ ],
+)
+
+py_test(
+ name = "prefetch_dataset_serialization_test",
+ size = "small",
+ srcs = ["prefetch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "range_dataset_serialization_test",
+ size = "small",
+ srcs = ["range_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sample_from_datasets_serialization_test",
+ size = "medium",
+ srcs = ["sample_from_datasets_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "scan_dataset_serialization_test",
+ size = "small",
+ srcs = ["scan_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sequence_dataset_serialization_test",
+ size = "medium",
+ srcs = ["sequence_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "serialization_integration_test",
+ size = "small",
+ srcs = ["serialization_integration_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "shuffle_and_repeat_dataset_serialization_test",
+ size = "medium",
+ srcs = ["shuffle_and_repeat_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "shuffle_dataset_serialization_test",
+ size = "medium",
+ srcs = ["shuffle_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sql_dataset_serialization_test",
+ size = "small",
+ srcs = ["sql_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_op_test_base",
+ "//tensorflow/python/data/experimental/ops:readers",
+ ],
+)
+
+py_test(
+ name = "stats_dataset_serialization_test",
+ size = "medium",
+ srcs = ["stats_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "textline_dataset_serialization_test",
+ size = "medium",
+ srcs = ["textline_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "tf_record_dataset_serialization_test",
+ size = "medium",
+ srcs = ["tf_record_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "unbatch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["unbatch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "unique_dataset_serialization_test",
+ size = "small",
+ srcs = ["unique_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "zip_dataset_serialization_test",
+ size = "small",
+ srcs = ["zip_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..d72a6df14c
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the BatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class BatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
+ components = (
+ np.arange(tensor_slice_len),
+ np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(tensor_slice_len))
+
+ return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size)
+
+ def testCore(self):
+ tensor_slice_len = 8
+ batch_size = 2
+ num_outputs = tensor_slice_len // batch_size
+ self.run_core_tests(
+ lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
+ lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
+ num_outputs)
+
+ def _build_dataset_dense_to_sparse(self, components):
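+    # Each element x is expanded to a length-x vector filled with x, then
+    # `dense_to_sparse_batch` packs 4 such vectors into a sparse batch with
+    # row shape [12].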
+ return dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.fill([x], x)).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+
+ def testDenseToSparseBatchDatasetCore(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
+
+ num_outputs = len(components) // 4
+ self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
+ lambda: self._build_dataset_dense_to_sparse(diff_comp),
+ num_outputs)
+
+ def _sparse(self, i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ def _build_dataset_sparse(self, batch_size=5):
+ return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)
+
+ def testSparseCore(self):
+ self.run_core_tests(self._build_dataset_sparse,
+ lambda: self._build_dataset_sparse(2), 2)
+
+ def _build_dataset_nested_sparse(self):
+ return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)
+
+ def testNestedSparseCore(self):
+ self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
new file mode 100644
index 0000000000..2bcf77f5d8
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
@@ -0,0 +1,253 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the CacheDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class CacheDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
+
+ def setUp(self):
+ self.range_size = 10
+ self.num_repeats = 3
+ self.num_outputs = self.range_size * self.num_repeats
+ self.cache_file_prefix = 'test'
+
+ def make_dataset_fn(self, is_memory):
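+    # An empty filename selects the in-memory cache; otherwise the dataset is
+    # cached to a file in the test's temporary directory.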
+ if is_memory:
+ filename = ''
+ else:
+ filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)
+
+ def ds_fn():
+ return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
+ self.num_repeats)
+
+ return ds_fn
+
+ def expected_outputs(self):
+ return list(range(self.range_size)) * self.num_repeats
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 5 entries from iterator and save checkpoint.
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint and produce the rest of the elements from the
+ # iterator.
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 8 entries from iterator but save checkpoint after producing 5.
+ outputs = self.gen_outputs(
+ ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, range(8))
+
+ if is_memory:
+ outputs = outputs[:5]
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+ else:
+      # Restoring from checkpoint and running GetNext should raise
+      # `AlreadyExistsError` now because the lockfile already exists.
+ with self.assertRaises(errors.AlreadyExistsError):
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointAfterOneEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 15 entries from iterator and save checkpoint.
+ outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
+
+ # Restore from checkpoint and produce the rest of the elements from the
+ # iterator.
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 15,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 18 entries from iterator but save checkpoint after producing 15.
+ outputs = self.gen_outputs(
+ ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))
+
+ outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 15,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 13 entries from iterator but save checkpoint after producing 5.
+ outputs = self.gen_outputs(
+ ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))
+
+    # Since we ran for more than one epoch, the cache was completely written.
+    # The checkpoint was saved while the iterator was in cache-write mode. Test
+    # that, after restoring, the iterator falls back to read mode because the
+    # cache has already been completely written.
+
+ outputs = list(range(5)) + self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointUnusedWriterIterator(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Checkpoint before get_next is called even once.
+ outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, [])
+
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Produce 5 elements and checkpoint.
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint, then produce no elements and checkpoint.
+ outputs.extend(
+ self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint and produce rest of the elements.
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testUnusedCheckpointError(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Produce 5 elements and save ckpt.
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ if is_memory:
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+ else:
+      # Because the complete cache has not been written, a new iterator that
+      # does not restore the checkpoint will raise an error: a partial cache
+      # shard already exists on disk.
+ with self.assertRaises(errors.AlreadyExistsError):
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testIgnoreCheckpointIfCacheWritten(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Produce 15 elements and save ckpt. This will write the complete cache.
+ outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
+
+ # Build the iterator again but do not restore from ckpt. Since the cache
+ # has already been written we should be able to use it.
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+
+if __name__ == '__main__':
+ test.main()
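
For readers skimming this diff: the dataset under test above is just a cached, repeated range. Below is a minimal sketch of that construction, mirroring `make_dataset_fn`; the `tempfile.mkdtemp()` path is an illustration-only assumption (the test itself uses `self.get_temp_dir()`).

import os
import tempfile

from tensorflow.python.data.ops import dataset_ops


def make_cache_dataset(is_memory, range_size=10, num_repeats=3):
  # An empty filename selects the in-memory cache; a real path writes cache
  # shards (and a lockfile) under that prefix, which is what the 'File'
  # variants above exercise via the AlreadyExistsError checks.
  filename = '' if is_memory else os.path.join(tempfile.mkdtemp(), 'test')
  return dataset_ops.Dataset.range(range_size).cache(filename).repeat(num_repeats)

Iterating either variant yields list(range(10)) three times, which is exactly what `expected_outputs()` returns.
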
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
new file mode 100644
index 0000000000..c075dff8cb
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
@@ -0,0 +1,49 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ConcatenateDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ConcatenateDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_concatenate_dataset(self, var_array):
+ input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 4))
+ to_concatenate_components = (np.tile(
+ np.array([[5], [6], [7], [8], [9]]), 20), var_array)
+
+ return dataset_ops.Dataset.from_tensor_slices(input_components).concatenate(
+ dataset_ops.Dataset.from_tensor_slices(to_concatenate_components))
+
+ def testConcatenateCore(self):
+ num_outputs = 9
+ array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15)
+ diff_array = np.array([[1], [2], [3], [4], [5]])
+ self.run_core_tests(lambda: self._build_concatenate_dataset(array),
+ lambda: self._build_concatenate_dataset(diff_array),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
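
A quick sanity check on the `num_outputs=9` passed above: `from_tensor_slices` produces one element per row of the leading dimension, so the concatenation yields 4 + 5 elements. A small numpy sketch of that arithmetic:

import numpy as np

first = np.tile(np.array([[1], [2], [3], [4]]), 20)        # shape (4, 20) -> 4 slices
second = np.tile(np.array([[5], [6], [7], [8], [9]]), 20)  # shape (5, 20) -> 5 slices
print(first.shape[0] + second.shape[0])  # 9
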
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
new file mode 100644
index 0000000000..d4983492e7
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
@@ -0,0 +1,73 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the CsvDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.platform import test
+
+
+class CsvDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._num_cols = 7
+ self._num_rows = 10
+ self._num_epochs = 14
+ self._num_outputs = self._num_rows * self._num_epochs
+
+ inputs = [
+ ",".join(str(self._num_cols * j + i)
+ for i in range(self._num_cols))
+ for j in range(self._num_rows)
+ ]
+ contents = "\n".join(inputs).encode("utf-8")
+
+ self._filename = os.path.join(self.get_temp_dir(), "file.csv")
+ self._compressed = os.path.join(self.get_temp_dir(),
+ "comp.csv") # GZip compressed
+
+ with open(self._filename, "wb") as f:
+ f.write(contents)
+ with gzip.GzipFile(self._compressed, "wb") as f:
+ f.write(contents)
+
+ def ds_func(self, **kwargs):
+ compression_type = kwargs.get("compression_type", None)
+ if compression_type == "GZIP":
+ filename = self._compressed
+ elif compression_type is None:
+ filename = self._filename
+ else:
+      raise ValueError("Invalid compression type: %s" % compression_type)
+
+ return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
+
+ def testSerializationCore(self):
+ defs = [[0]] * self._num_cols
+ self.run_core_tests(
+ lambda: self.ds_func(record_defaults=defs, buffer_size=2),
+ lambda: self.ds_func(record_defaults=defs, buffer_size=12),
+ self._num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
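
For context on `ds_func` above, here is a standalone graph-mode sketch of reading the same kind of file with `CsvDataset`. The direct `Session` construction, the temp path, and the one-shot iterator are illustration-only assumptions; the test itself drives the dataset through the serialization helpers instead.

import os
import tempfile

from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.framework import ops

# Write a small 7-column CSV like the one setUp() generates.
num_cols, num_rows = 7, 10
rows = [",".join(str(num_cols * j + i) for i in range(num_cols))
        for j in range(num_rows)]
filename = os.path.join(tempfile.mkdtemp(), "file.csv")
with open(filename, "w") as f:
  f.write("\n".join(rows))

with ops.Graph().as_default() as g:
  # record_defaults fixes the number of columns and their dtype
  # ([0] => int32 scalar column with default value 0).
  dataset = readers.CsvDataset(
      filename, record_defaults=[[0]] * num_cols, buffer_size=2)
  get_next = dataset.make_one_shot_iterator().get_next()
  with session.Session(graph=g) as sess:
    print(sess.run(get_next))  # first row: (0, 1, 2, 3, 4, 5, 6)
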
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
new file mode 100644
index 0000000000..41a095fb1a
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
@@ -0,0 +1,95 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the dataset constructors serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class FromTensorsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_tensor_dataset(self, variable_array):
+ components = (variable_array, np.array([1, 2, 3]), np.array(37.0))
+
+ return dataset_ops.Dataset.from_tensors(components)
+
+ def testFromTensorsCore(self):
+ # Equal length components
+ arr = np.array(1)
+ num_outputs = 1
+ diff_arr = np.array(2)
+ self.run_core_tests(lambda: self._build_tensor_dataset(arr),
+ lambda: self._build_tensor_dataset(diff_arr),
+ num_outputs)
+
+
+class FromTensorSlicesSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_tensor_slices_dataset(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components)
+
+ def testFromTensorSlicesCore(self):
+ # Equal length components
+ components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 22),
+ np.array([37.0, 38.0, 39.0, 40.0]))
+
+ diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[5], [6], [7], [8]]), 22),
+ np.array([1.0, 2.0, 3.0, 4.0]))
+
+ dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
+
+ self.run_core_tests(lambda: self._build_tensor_slices_dataset(components),
+ lambda: self._build_tensor_slices_dataset(diff_comp), 4)
+ self.run_core_tests(
+ lambda: self._build_tensor_slices_dataset(dict_components), None, 3)
+
+
+class FromSparseTensorSlicesSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_sparse_tensor_slice_dataset(self, slices):
+ indices = np.array(
+ [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))],
+ dtype=np.int64)
+ values = np.array([val for s in slices for val in s], dtype=np.float64)
+ dense_shape = np.array(
+ [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64)
+ sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape)
+ return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components)
+
+ def testFromSparseTensorSlicesCore(self):
+ slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
+ diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []]
+
+ self.run_core_tests(
+ lambda: self._build_sparse_tensor_slice_dataset(slices),
+ lambda: self._build_sparse_tensor_slice_dataset(diff_slices),
+ 9,
+ sparse_tensors=True)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
new file mode 100644
index 0000000000..7f435b8239
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -0,0 +1,692 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing serializable datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.util import nest
+
+
+def remove_variants(get_next_op):
+  """Remove variants from a nest structure, so sess.run will execute."""
+  # TODO(b/72408568): Remove this once session.run can get variant tensors.
+
+ def _remove_variant(x):
+ if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
+ return ()
+ else:
+ return x
+
+ return nest.map_structure(_remove_variant, get_next_op)
+
+
+class DatasetSerializationTestBase(test.TestCase):
+ """Base class for testing serializable datasets."""
+
+ def tearDown(self):
+ self._delete_ckpt()
+
+ # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
+ # (deprecated) saveable `SparseTensorSliceDataset`, once the API
+  # `from_sparse_tensor_slices()` and related tests are deleted.
+ def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
+ """Runs the core tests.
+
+ Args:
+ ds_fn1: 0-argument function that returns a Dataset.
+ ds_fn2: 0-argument function that returns a Dataset different from
+ ds_fn1. If None, verify_restore_in_modified_graph test is not run.
+ num_outputs: Total number of outputs expected from this Dataset.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_unused_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_fully_used_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_exhausted_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_init_before_restore(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_multiple_breaks(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_reset_restored_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_restore_in_empty_graph(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ if ds_fn2:
+ self.verify_restore_in_modified_graph(
+ ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_unused_iterator(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that saving and restoring an unused iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [0],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_fully_used_iterator(self, ds_fn, num_outputs,
+ sparse_tensors=False):
+ """Verifies that saving and restoring a fully used iterator works.
+
+ Note that this only checks saving and restoring an iterator from which
+ `num_outputs` items have been produced but does not check for an
+ exhausted iterator, i.e., one from which an OutOfRange error has been
+ returned.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
+ """Verifies that saving and restoring an exhausted iterator works.
+
+ An exhausted iterator is one which has returned an OutOfRange error.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ actual = self.gen_outputs(
+ ds_fn, [],
+ 0,
+ ckpt_saved=True,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ self.assertEqual(len(actual), 0)
+
+ def verify_init_before_restore(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that restoring into an already initialized iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs),
+ num_outputs,
+ init_before_restore=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_multiple_breaks(self,
+ ds_fn,
+ num_outputs,
+ num_breaks=10,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to save/restore at multiple break points.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+      num_breaks: The number of break points. These are uniformly spread over
+        [0, num_outputs], both endpoints inclusive.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs, num_breaks),
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_reset_restored_iterator(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to re-initialize a restored iterator.
+
+ This is useful when restoring a training checkpoint during validation.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Collect ground truth containing all outputs.
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Skip some items and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Restore from checkpoint and then run init_op.
+ with ops.Graph().as_default() as g:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ self._initialize(init_op, sess)
+ for _ in range(num_outputs):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ self.match(expected, actual)
+
+ def verify_restore_in_modified_graph(self,
+ ds_fn1,
+ ds_fn2,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in a modified graph.
+
+ Builds an input pipeline using ds_fn1, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new graph using ds_fn2, restores
+ the checkpoint from ds_fn1 and verifies that the restore is successful.
+
+ Args:
+ ds_fn1: See `run_core_tests`.
+ ds_fn2: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn1
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn1, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn1 and save checkpoint.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build graph for ds_fn2 but load checkpoint for ds_fn1.
+ with ops.Graph().as_default() as g:
+ _, get_next_op, saver = self._build_graph(
+ ds_fn2, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_restore_in_empty_graph(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in an empty graph.
+
+ Builds an input pipeline using ds_fn, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new empty graph, restores
+ the checkpoint from ds_fn and verifies that the restore is successful.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build an empty graph but load checkpoint for ds_fn.
+ with ops.Graph().as_default() as g:
+ get_next_op, saver = self._build_empty_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_error_on_save(self,
+ ds_fn,
+ num_outputs,
+ error,
+ break_point=None,
+ sparse_tensors=False):
+ """Attempts to save a non-saveable iterator.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ error: Declared error when trying to save iterator.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+
+ break_point = num_outputs // 2 if not break_point else break_point
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._initialize(init_op, sess)
+ for _ in range(break_point):
+ sess.run(get_next_op)
+ with self.assertRaises(error):
+ self._save(sess, saver)
+
+ def verify_run_with_breaks(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that ds_fn() produces the same outputs with and without breaks.
+
+ 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ *without* stopping at break points.
+    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+       *while* stopping at the given break points.
+
+    Deep-matches the outputs from 1 and 2.
+
+ Args:
+ ds_fn: See `gen_outputs`.
+ break_points: See `gen_outputs`.
+ num_outputs: See `gen_outputs`.
+ init_before_restore: See `gen_outputs`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ self.match(expected, actual)
+
+ def gen_outputs(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ ckpt_saved=False,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True,
+ save_checkpoint_at_end=True):
+ """Generates elements from input dataset while stopping at break points.
+
+ Produces `num_outputs` outputs and saves the state of the iterator in the
+ Saver checkpoint.
+
+ Args:
+ ds_fn: 0-argument function that returns the dataset.
+      break_points: A list of integers. For each `break_point` in
+        `break_points`, we produce outputs until `break_point` items have been
+        produced and then checkpoint the state. The current graph and session
+        are destroyed, and a new graph and session are used to produce outputs
+        until the next break point or until `num_outputs` elements have been
+        produced. Each `break_point` must be <= `num_outputs`.
+ num_outputs: The total number of outputs to produce from the iterator.
+ ckpt_saved: Whether a checkpoint already exists. If False, we build the
+ graph from ds_fn.
+ init_before_restore: Whether init should be called before saver.restore.
+ This is just so that we can verify that restoring an already initialized
+ iterator works.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+ verify_exhausted: Whether to verify that the iterator has been exhausted
+ after producing `num_outputs` elements.
+ save_checkpoint_at_end: Whether to save a checkpoint after producing all
+        outputs. If False, checkpoints are saved at each break point but not at
+        the end. Note that checkpoints overwrite each other, so only a single
+        checkpoint is ever available. Defaults to True.
+
+ Returns:
+ A list of `num_outputs` items.
+ """
+ outputs = []
+
+ def get_ops():
+ if ckpt_saved:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ else:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ return init_op, get_next_op, saver
+
+ for i in range(len(break_points) + 1):
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = get_ops()
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ if ckpt_saved:
+ if init_before_restore:
+ self._initialize(init_op, sess)
+ self._restore(saver, sess)
+ else:
+ self._initialize(init_op, sess)
+ start = break_points[i - 1] if i > 0 else 0
+ end = break_points[i] if i < len(break_points) else num_outputs
+ num_iters = end - start
+ for _ in range(num_iters):
+ outputs.append(sess.run(get_next_op))
+ if i == len(break_points) and verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ if save_checkpoint_at_end or i < len(break_points):
+ self._save(sess, saver)
+ ckpt_saved = True
+
+ return outputs
+
+ def match(self, expected, actual):
+ """Matches nested structures.
+
+ Recursively matches shape and values of `expected` and `actual`.
+ Handles scalars, numpy arrays and other python sequence containers
+ e.g. list, dict.
+
+ Args:
+ expected: Nested structure 1.
+ actual: Nested structure 2.
+
+ Raises:
+ AssertionError if matching fails.
+ """
+ if isinstance(expected, np.ndarray):
+ expected = expected.tolist()
+ if isinstance(actual, np.ndarray):
+ actual = actual.tolist()
+ self.assertEqual(type(expected), type(actual))
+
+ if nest.is_sequence(expected):
+ self.assertEqual(len(expected), len(actual))
+ if isinstance(expected, dict):
+ for key1, key2 in zip(sorted(expected), sorted(actual)):
+ self.assertEqual(key1, key2)
+ self.match(expected[key1], actual[key2])
+ else:
+ for item1, item2 in zip(expected, actual):
+ self.match(item1, item2)
+ else:
+ self.assertEqual(expected, actual)
+
+ def does_not_match(self, expected, actual):
+ with self.assertRaises(AssertionError):
+ self.match(expected, actual)
+
+ def gen_break_points(self, num_outputs, num_samples=10):
+ """Generates `num_samples` breaks points in [0, num_outputs]."""
+ return np.linspace(0, num_outputs, num_samples, dtype=int)
+
+ def _build_graph(self, ds_fn, sparse_tensors=False):
+ iterator = ds_fn().make_initializable_iterator()
+
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ init_op = iterator.initializer
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
+ sparse_tensors)
+ saver = saver_lib.Saver(allow_empty=True)
+ return init_op, get_next, saver
+
+ def _build_empty_graph(self, ds_fn, sparse_tensors=False):
+ iterator = iterator_ops.Iterator.from_structure(
+ self._get_output_types(ds_fn),
+ output_shapes=self._get_output_shapes(ds_fn),
+ output_classes=self._get_output_classes(ds_fn))
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ saver = saver_lib.Saver(allow_empty=True)
+ return get_next, saver
+
+ def _add_iterator_ops_to_collection(self,
+ init_op,
+ get_next,
+ ds_fn,
+ sparse_tensors=False):
+ ops.add_to_collection("iterator_ops", init_op)
+    # `get_next` may be a tuple, e.g. for TensorSliceDataset. Since collections
+    # do not support tuples, we flatten the tensors and restore the shape in
+    # `_get_iterator_ops_from_collection`.
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ ops.add_to_collection("iterator_ops", get_next.indices)
+ ops.add_to_collection("iterator_ops", get_next.values)
+ ops.add_to_collection("iterator_ops", get_next.dense_shape)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
+
+ def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
+ all_ops = ops.get_collection("iterator_ops")
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ init_op, indices, values, dense_shape = all_ops
+ return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
+
+ def _get_output_types(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_types
+
+ def _get_output_shapes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_shapes
+
+ def _get_output_classes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_classes
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _latest_ckpt(self):
+ return checkpoint_management.latest_checkpoint(self.get_temp_dir())
+
+ def _save(self, sess, saver):
+ saver.save(sess, self._ckpt_path())
+
+ def _restore(self, saver, sess):
+ sess.run(lookup_ops.tables_initializer())
+ saver.restore(sess, self._latest_ckpt())
+
+ def _initialize(self, init_op, sess):
+ sess.run(variables.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ sess.run(init_op)
+
+ def _import_meta_graph(self):
+ meta_file_path = self._ckpt_path() + ".meta"
+ return saver_lib.import_meta_graph(meta_file_path)
+
+ def _delete_ckpt(self):
+ # Remove all checkpoint files.
+ prefix = self._ckpt_path()
+ pattern = prefix + "*"
+ files = gfile.Glob(pattern)
+    # Use an explicit loop: in Python 3, `map()` is lazy and would not delete.
+    for f in files:
+      gfile.Remove(f)
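
The helpers above all encode one protocol: register the iterator as a saveable, checkpoint mid-stream, then rebuild (or import) a graph and restore. Below is a condensed standalone sketch of that protocol reusing the internal modules imported at the top of this file; the range dataset, element counts, temp path, and the direct `Session` construction are illustration-only assumptions (the base class uses `self.session()` and `self.get_temp_dir()` instead).

import os
import tempfile

from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import iterator_ops as exp_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import saver as saver_lib

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "iterator")


def build_graph():
  # Mirrors _build_graph(): register the iterator state so Saver checkpoints it.
  iterator = dataset_ops.Dataset.range(10).make_initializable_iterator()
  saveable = exp_iterator_ops.make_saveable_from_iterator(iterator)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
  return iterator.initializer, iterator.get_next(), saver_lib.Saver(allow_empty=True)


# Produce three elements, then checkpoint the iterator state.
with ops.Graph().as_default() as g:
  init_op, get_next, saver = build_graph()
  with session.Session(graph=g) as sess:
    sess.run(init_op)
    print([sess.run(get_next) for _ in range(3)])  # [0, 1, 2]
    saver.save(sess, ckpt_prefix)

# Rebuild the same pipeline in a fresh graph and restore; iteration resumes at 3.
with ops.Graph().as_default() as g:
  init_op, get_next, saver = build_graph()
  with session.Session(graph=g) as sess:
    saver.restore(sess, ckpt_prefix)
    print([sess.run(get_next) for _ in range(3)])  # [3, 4, 5]

This is essentially what `_build_graph`, `_save`, and `_restore` do, with `gen_outputs` orchestrating the break points between fresh graphs and sessions.
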
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
new file mode 100644
index 0000000000..225f6cbac0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the FilterDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FilterDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_filter_range_graph(self, div):
+ return dataset_ops.Dataset.range(100).filter(
+ lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))
+
+ def testFilterCore(self):
+ div = 3
+ num_outputs = np.sum([x % 3 != 2 for x in range(100)])
+ self.run_core_tests(lambda: self._build_filter_range_graph(div),
+ lambda: self._build_filter_range_graph(div * 2),
+ num_outputs)
+
+ def _build_filter_dict_graph(self):
+ return dataset_ops.Dataset.range(10).map(
+ lambda x: {"foo": x * 2, "bar": x ** 2}).filter(
+ lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
+ lambda d: d["foo"] + d["bar"])
+
+ def testFilterDictCore(self):
+ num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)])
+ self.run_core_tests(self._build_filter_dict_graph, None, num_outputs)
+
+ def _build_sparse_filter(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
+
+ def _filter_fn(_, i):
+ return math_ops.equal(i % 2, 0)
+
+ return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
+ lambda x, i: x)
+
+ def testSparseCore(self):
+ num_outputs = 5
+ self.run_core_tests(self._build_sparse_filter, None, num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
new file mode 100644
index 0000000000..70caf3e0d5
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
@@ -0,0 +1,45 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the FixedLengthRecordDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class FixedLengthRecordDatasetSerializationTest(
+ reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self, num_epochs, compression_type=None):
+ filenames = self._createFiles()
+ return core_readers.FixedLengthRecordDataset(
+ filenames, self._record_bytes, self._header_bytes,
+ self._footer_bytes).repeat(num_epochs)
+
+ def testFixedLengthRecordCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
+ lambda: self._build_iterator_graph(num_epochs * 2),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
new file mode 100644
index 0000000000..c30534a9e9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
@@ -0,0 +1,122 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the FlatMapDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class FlatMapDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testCore(self):
+ # Complicated way of saying range(start, start+25).
+ def build_ds(start):
+
+ def map_fn(x):
+ return dataset_ops.Dataset.range(x, x + 5)
+
+ return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn)
+
+ self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25)
+
+ def testMapThenFlatMap(self):
+
+ def build_ds():
+
+ def flat_map_fn(_):
+
+ def map_fn(y):
+ return 10 * math_ops.to_int32(y)
+
+ return dataset_ops.Dataset.range(100).map(map_fn)
+
+ return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
+
+ self.run_core_tests(build_ds, None, 500)
+
+ def testCaptureDefunInMapFn(self):
+
+ def build_ds():
+
+ def map_fn(x):
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)])
+
+ return dataset_ops.Dataset.range(100).flat_map(map_fn)
+
+ self.run_core_tests(build_ds, None, 100)
+
+ def testDisallowVariableCapture(self):
+
+ def build_ds():
+ test_var = variable_scope.get_variable(
+ name="test_var", shape=(), use_resource=True)
+ return dataset_ops.Dataset.range(5).flat_map(
+ lambda _: dataset_ops.Dataset.from_tensor_slices([test_var]))
+
+ self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError)
+
+ def testDisallowCapturingStatefulOps(self):
+
+ def build_ds():
+
+ def flat_map_fn(_):
+
+ def map_fn(x):
+ return random_ops.random_uniform(
+ (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(100).map(map_fn)
+
+ return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
+
+ self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError)
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _flat_map_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_ds():
+ return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
+
+ self.run_core_tests(_build_ds, None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
new file mode 100644
index 0000000000..169c8845d0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
@@ -0,0 +1,61 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the GroupByReducer serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class GroupByReducerSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, components):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ return dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_reducer(lambda x: x % 5, reducer))
+
+ def testCoreGroupByReducer(self):
+ components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
+ self.verify_unused_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_init_before_restore(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_multiple_breaks(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_restore_in_empty_graph(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_dataset(components),
+ lambda: self._build_dataset(diff_components),
+ 5,
+ verify_exhausted=True)
+
+
+if __name__ == '__main__':
+ test.main()
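
Why `5` is the expected output count in every call above: the reducer sums `range(10)` grouped by `x % 5`, producing one element per key. A numpy sketch of that arithmetic:

import numpy as np

components = np.arange(10, dtype=np.int64)
sums = {int(k): int(components[components % 5 == k].sum()) for k in range(5)}
print(sums)  # {0: 5, 1: 7, 2: 9, 3: 11, 4: 13}
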
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
new file mode 100644
index 0000000000..e5bc76288e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the GroupByWindow serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class GroupByWindowSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
+ grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
+
+ def testCoreGroupByWindow(self):
+ components = np.array(
+ [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
+ self.verify_unused_iterator(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_init_before_restore(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_multiple_breaks(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_restore_in_empty_graph(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_dataset(components),
+ lambda: self._build_dataset(diff_components),
+ 12,
+ verify_exhausted=False)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
new file mode 100644
index 0000000000..df1f43129a
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
@@ -0,0 +1,46 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the IgnoreErrors input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class IgnoreErrorsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_ds(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.check_numerics(x, "message")).apply(
+ error_ops.ignore_errors())
+
+ def testIgnoreErrorsCore(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+ diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32)
+ num_outputs = 4
+ self.run_core_tests(lambda: self._build_ds(components),
+ lambda: self._build_ds(diff_components), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
new file mode 100644
index 0000000000..0c1d40ce39
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the InterleaveDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class InterleaveDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
+
+ def _build_iterator_graph(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
+ repeat_count = 2
+ return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+ repeat_count).interleave(
+ lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+ cycle_length, block_length, num_parallel_calls)
+
+ @parameterized.named_parameters(
+ ("1", 2, 3, None),
+ ("2", 2, 3, 1),
+ ("3", 2, 3, 2),
+ ("4", 1, 3, None),
+ ("5", 1, 3, 1),
+ ("6", 2, 1, None),
+ ("7", 2, 1, 1),
+ ("8", 2, 1, 2),
+ )
+ def testSerializationCore(self, cycle_length, block_length,
+ num_parallel_calls):
+ input_values = np.array([4, 5, 6], dtype=np.int64)
+ num_outputs = np.sum(input_values) * 2
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(
+ input_values, cycle_length, block_length, num_parallel_calls),
+ lambda: self._build_iterator_graph(
+ input_values, cycle_length * 2, block_length, num_parallel_calls),
+ num_outputs)
+ # pylint: enable=g-long-lambda
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_dataset():
+ return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
+ _interleave_fn, cycle_length=1)
+
+ self.run_core_tests(_build_dataset, None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..166ffa99ca
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
@@ -0,0 +1,88 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testNumParallelBatches(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
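+    # 11 * 2 = 22 source elements batched in groups of 5 yield 4 full batches
+    # when the remainder is dropped and 5 batches otherwise.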
+ num_parallel_batches = 2
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ return dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_batches=num_parallel_batches,
+ drop_remainder=drop_remainder))
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
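+  # Same as testNumParallelBatches, but exercises the num_parallel_calls
+  # argument instead of num_parallel_batches.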
+ def testNumParallelCalls(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_calls = 7
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ return dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ drop_remainder=drop_remainder))
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
new file mode 100644
index 0000000000..b93156a96c
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class MapDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._tensor_slice_len = 7
+ self._num_epochs = 14
+ self._num_outputs = self._tensor_slice_len * self._num_epochs
+
+ def _build_ds(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(self._num_epochs))
+
+ def testSaveRestoreCore(self):
+ self.run_core_tests(
+ self._build_ds,
+ lambda: self._build_ds(multiplier=15.0),
+ self._num_outputs)
+
+ def testSaveStatefulFunction(self):
+
+ def _build_ds():
+
+ def _map_fn(x):
+ return random_ops.random_uniform(
+ (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(100).map(_map_fn)
+
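+    # The random op makes the map function stateful, so saving the iterator
+    # is expected to fail.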
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureVariableInMapFn(self):
+
+ def _build_ds():
+ counter_var = variable_scope.get_variable(
+ "counter", (), dtypes.int32, use_resource=True)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda _: counter_var.assign_add(1)))
+
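+    # Iterators whose map function mutates a captured resource variable
+    # cannot be checkpointed either.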
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureConstantInMapFn(self):
+
+ def _build_ds():
+ constant_var = constant_op.constant(5)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda x: x + constant_var))
+
+ self.run_core_tests(_build_ds, None, 10)
+
+ def testCaptureDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testBuildDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+
+ @function.Defun(dtypes.int32)
+ def defun_fn_deep(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
+
+ return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testSparseCore(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def _build_ds(num_outputs):
+ return dataset_ops.Dataset.range(num_outputs).map(_sparse)
+
+ num_outputs = 10
+ self.run_core_tests(lambda: _build_ds(num_outputs),
+ lambda: _build_ds(int(num_outputs / 2)), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
new file mode 100644
index 0000000000..ed4a1da596
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
@@ -0,0 +1,39 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the OptimizeDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testCore(self):
+
+ def build_dataset(num_elements, batch_size):
+ return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch(
+ batch_size).apply(optimization.optimize(["map_and_batch_fusion"]))
+
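+    # 200 elements batched in groups of 10 yield 20 batches; the
+    # map_and_batch_fusion optimization should not change that count.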
+ self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..6f72b24673
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
@@ -0,0 +1,66 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the PaddedBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class PaddedBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testPaddedBatch(self):
+
+ def build_dataset(seq_lens):
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ lambda x: array_ops.fill([x], x)).padded_batch(
+ 4, padded_shapes=[-1])
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
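+    # 32 variable-length sequences batched in groups of 4 yield 8 padded
+    # batches.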
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+ def testPaddedBatchNonDefaultPadding(self):
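+    # Like testPaddedBatch, but pads the integer component with -1 and the
+    # string component with "<end>" instead of the default padding values.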
+
+ def build_dataset(seq_lens):
+
+ def fill_tuple(x):
+ filled = array_ops.fill([x], x)
+ return (filled, string_ops.as_string(filled))
+
+ padded_shape = [-1]
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ fill_tuple).padded_batch(
+ 4,
+ padded_shapes=(padded_shape, padded_shape),
+ padding_values=(-1, "<end>"))
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
new file mode 100644
index 0000000000..b8f38e8a28
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParallelInterleaveDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class ParallelInterleaveDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self.input_values = np.array([4, 5, 6], dtype=np.int64)
+ self.num_repeats = 2
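+    # Each value x yields the range [10 * x, 11 * x), i.e. x elements, and
+    # the input is repeated twice, so sum([4, 5, 6]) * 2 = 30 elements total.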
+ self.num_outputs = np.sum(self.input_values) * 2
+
+ def _build_ds(self, cycle_length, block_length, sloppy=False):
+ return (dataset_ops.Dataset.from_tensor_slices(
+ self.input_values).repeat(self.num_repeats).apply(
+ interleave_ops.parallel_interleave(
+ lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
+ cycle_length, block_length, sloppy)))
+
+ def testSerializationCore(self):
+ # cycle_length > 1, block_length > 1
+ cycle_length = 2
+ block_length = 3
+ self.run_core_tests(
+ lambda: self._build_ds(cycle_length, block_length),
+        lambda: self._build_ds(cycle_length * 2, block_length),
+ self.num_outputs)
+ # cycle_length = 1
+ cycle_length = 1
+ block_length = 3
+ self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+ None, self.num_outputs)
+ # block_length = 1
+ cycle_length = 2
+ block_length = 1
+ self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+ None, self.num_outputs)
+
+ def testSerializationWithSloppy(self):
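+    # With sloppy=True the output order is nondeterministic, so the actual
+    # outputs are sorted before being compared to the expected values.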
+ break_points = self.gen_break_points(self.num_outputs, 10)
+ expected_outputs = np.repeat(
+ np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]),
+ self.num_repeats).tolist()
+
+ def run_test(cycle_length, block_length):
+ actual = self.gen_outputs(
+ lambda: self._build_ds(cycle_length, block_length, True),
+ break_points, self.num_outputs)
+ self.assertSequenceEqual(sorted(actual), expected_outputs)
+
+ # cycle_length > 1, block_length > 1
+ run_test(2, 3)
+ # cycle_length = 1
+ run_test(1, 3)
+ # block_length = 1
+ run_test(2, 1)
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_dataset():
+ return dataset_ops.Dataset.range(10).map(_map_fn).apply(
+ interleave_ops.parallel_interleave(_interleave_fn, 1))
+
+ self.run_core_tests(_build_dataset, None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
new file mode 100644
index 0000000000..a0bdd4fa59
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParallelMapDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class ParallelMapDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._tensor_slice_len = 7
+ self._num_epochs = 1
+ self._num_outputs = self._tensor_slice_len * self._num_epochs
+
+ def _build_ds(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn, num_parallel_calls=3).repeat(self._num_epochs))
+
+ def _build_ds_with_prefetch(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5))
+
+ def testSaveRestoreCore(self):
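+    # Run the core save/restore tests both with and without a trailing
+    # prefetch after the parallel map.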
+ for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
+ self.run_core_tests(
+ ds_fn,
+ lambda: ds_fn(multiplier=15.0), # pylint: disable=cell-var-from-loop
+ self._num_outputs)
+
+ def testSaveStatefulFunction(self):
+
+ def _build_ds():
+
+ def _map_fn(x):
+ return random_ops.random_uniform(
+ (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(100).map(
+ _map_fn, num_parallel_calls=2).prefetch(2)
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureVariableInMapFn(self):
+
+ def _build_ds():
+ counter_var = variable_scope.get_variable(
+ "counter", (), dtypes.int32, use_resource=True)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda _: counter_var.assign_add(1),
+ num_parallel_calls=2).prefetch(2))
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureConstantInMapFn(self):
+
+ def _build_ds():
+ constant_var = constant_op.constant(5)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda x: x + constant_var, num_parallel_calls=2).prefetch(2))
+
+ self.run_core_tests(_build_ds, None, 10)
+
+ def testCaptureDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(num_outputs).map(
+ defun_fn, num_parallel_calls=2).prefetch(2)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testBuildDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+
+ @function.Defun(dtypes.int32)
+ def defun_fn_deep(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
+
+ return dataset_ops.Dataset.range(num_outputs).map(
+ defun_fn, num_parallel_calls=2).prefetch(2)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
new file mode 100644
index 0000000000..a0dd6960b0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParseExampleDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.platform import test
+
+
+class ParseExampleDatasetSerializationTest(
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def ParseExampleDataset(self, num_repeat, batch_size):
+ return self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_repeat,
+ batch_size=batch_size,
+ reader_num_threads=5,
+ parser_num_threads=10)
+
+ def testSerializationCore(self):
+ num_repeat = 5
+ batch_size = 2
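+    # num_outputs is derived from the number of records the test base writes
+    # to disk, repeated num_repeat times and grouped into batches.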
+ num_outputs = self._num_records * self._num_files * num_repeat // batch_size
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self.ParseExampleDataset(
+ num_repeat=num_repeat, batch_size=batch_size),
+ lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
new file mode 100644
index 0000000000..00d74c0025
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the PrefetchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class PrefetchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, seed):
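+    # reshuffle_each_iteration=False keeps the shuffle order a pure function
+    # of the seed, so the restored iterator reproduces the same sequence.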
+ return dataset_ops.Dataset.range(100).prefetch(10).shuffle(
+ buffer_size=10, seed=seed, reshuffle_each_iteration=False)
+
+ def testCore(self):
+ num_outputs = 100
+ self.run_core_tests(lambda: self.build_dataset(10),
+ lambda: self.build_dataset(20), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
new file mode 100644
index 0000000000..ef99d01c73
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
@@ -0,0 +1,118 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the RangeDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class RangeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _iterator_checkpoint_prefix_local(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(self, iterator_resource):
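+    # Serializes the iterator state into a variant tensor and writes it to a
+    # local checkpoint file.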
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_prefix_local(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
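+    # Reads the serialized state back from the checkpoint file and feeds it
+    # to deserialize_iterator to restore the iterator.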
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_prefix_local()),
+ dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def testSaveRestore(self):
+
+ def _build_graph(start, stop):
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop)
+ with self.session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Saving and restoring in same session.
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def _build_range_dataset(self, start, stop):
+ return dataset_ops.Dataset.range(start, stop)
+
+ def testRangeCore(self):
+ start = 2
+ stop = 10
+ stop_1 = 8
+ self.run_core_tests(lambda: self._build_range_dataset(start, stop),
+ lambda: self._build_range_dataset(start, stop_1),
+ stop - start)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
new file mode 100644
index 0000000000..c23c1ecdfb
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SampleFromDatasets serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class SampleFromDatasetsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, probs, num_samples):
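+    # Builds one infinitely repeating constant dataset per probability and
+    # samples from them with a fixed seed so that the draws are deterministic.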
+ dataset = interleave_ops.sample_from_datasets(
+ [
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(len(probs))
+ ],
+ probs,
+ seed=1813)
+ return dataset.take(num_samples)
+
+ def testSerializationCore(self):
+ self.run_core_tests(
+ lambda: self._build_dataset([0.5, 0.5], 100),
+ lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
new file mode 100644
index 0000000000..5f50160619
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ScanDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ScanDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_elements):
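+    # The scan state is a Fibonacci-style pair; each step emits the second
+    # state element, so every output depends on all preceding steps.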
+ return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
+ scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
+
+ def testScanCore(self):
+ num_output = 5
+ self.run_core_tests(lambda: self._build_dataset(num_output),
+ lambda: self._build_dataset(2), num_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
new file mode 100644
index 0000000000..fe99a3d3d9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
@@ -0,0 +1,129 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the sequence datasets serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class SkipDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_skip_dataset(self, count):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).skip(count)
+
+ def testSkipFewerThanInputs(self):
+ count = 4
+ num_outputs = 10 - count
+ self.run_core_tests(lambda: self._build_skip_dataset(count),
+ lambda: self._build_skip_dataset(count + 2),
+ num_outputs)
+
+ def testSkipVarious(self):
+ # Skip more than inputs
+ self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0)
+ # Skip exactly the input size
+ self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0)
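+    # Skip the entire input (count == -1)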
+ self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0)
+ # Skip nothing
+ self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)
+
+ def testInvalidSkip(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'Shape must be rank 0 but is rank 1'):
+ self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)
+
+
+class TakeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_take_dataset(self, count):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).take(count)
+
+ def testTakeFewerThanInputs(self):
+ count = 4
+ self.run_core_tests(
+ lambda: self._build_take_dataset(count),
+ lambda: self._build_take_dataset(count + 2),
+ count,
+ )
+
+ def testTakeVarious(self):
+ # Take more than inputs
+ self.run_core_tests(lambda: self._build_take_dataset(20), None, 10)
+ # Take exactly the input size
+ self.run_core_tests(lambda: self._build_take_dataset(10), None, 10)
+ # Take all
+ self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10)
+ # Take nothing
+ self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)
+
+ def testInvalidTake(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'Shape must be rank 0 but is rank 1'):
+ self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)
+
+
+class RepeatDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_repeat_dataset(self, count, take_count=3):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).take(
+ take_count).repeat(count)
+
+ def testFiniteRepeat(self):
+ count = 10
+ self.run_core_tests(lambda: self._build_repeat_dataset(count),
+ lambda: self._build_repeat_dataset(count + 2),
+ 3 * count)
+
+ def testEmptyRepeat(self):
+ self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0)
+
+ def testInfiniteRepeat(self):
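+    # An infinite repeat never exhausts, so each check consumes a fixed
+    # number of outputs with verify_exhausted=False.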
+ self.verify_unused_iterator(
+ lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
+ self.verify_init_before_restore(
+ lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
+ self.verify_multiple_breaks(
+ lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_repeat_dataset(-1),
+ lambda: self._build_repeat_dataset(2),
+ 20,
+ verify_exhausted=False)
+ # Test repeat empty dataset
+ self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0)
+
+ def testInvalidRepeat(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Shape must be rank 0 but is rank 1'):
+ self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0),
+ None, 0)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
new file mode 100644
index 0000000000..88d5c896c9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Integration test for dataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class SerializationIntegrationTest(test.TestCase):
+
+ def _build_input_pipeline(self, name, num_outputs):
+ with ops.name_scope(name):
+ ds = dataset_ops.Dataset.range(num_outputs).shuffle(
+ 10, reshuffle_each_iteration=False).prefetch(10)
+ iterator = ds.make_initializable_iterator()
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ return iterator.initializer, iterator.get_next()
+
+ def _build_graph(self, num_pipelines, num_outputs):
+ init_ops = []
+ get_next_ops = []
+ for i in range(num_pipelines):
+ name = "input_pipeline_%d" % i
+ init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
+ init_ops.append(init_op)
+ get_next_ops.append(get_next_op)
+ saver = saver_lib.Saver()
+ return init_ops, get_next_ops, saver
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def testConcurrentSaves(self):
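+    # Builds 100 independent shuffled pipelines, checkpoints all of them with
+    # a single Saver after 10 steps, and verifies that restoring resumes each
+    # pipeline where it left off.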
+ num_pipelines = 100
+ num_outputs = 100
+ break_point = 10
+ all_outputs = [[] for _ in range(num_pipelines)]
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ sess.run(init_ops)
+ for _ in range(break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+ saver.save(sess, self._ckpt_path())
+
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ saver.restore(sess, self._ckpt_path())
+ for _ in range(num_outputs - break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+
+ for output in all_outputs:
+ self.assertSequenceEqual(sorted(output), range(num_outputs))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
new file mode 100644
index 0000000000..f847ac19f9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ShuffleAndRepeatDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ShuffleAndRepeatSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_ds(self, seed):
+ return dataset_ops.Dataset.range(20).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
+
+ def testCore(self):
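+    # 20 elements repeated 5 times yield 100 outputs.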
+ self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
+ 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
new file mode 100644
index 0000000000..a04f1ddafc
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
@@ -0,0 +1,148 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ShuffleDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class ShuffleDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_shuffle_dataset(
+ self,
+ range_limit=10,
+ num_repeats=5,
+ buffer_size=5,
+ seed=None,
+ reshuffle_each_iteration=None,
+ ):
+ return dataset_ops.Dataset.range(range_limit).shuffle(
+ buffer_size,
+ seed=seed,
+ reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)
+
+ def testShuffleCore(self):
+
+ seed = 55
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+ # pylint: disable=cell-var-from-loop
+ # pylint: disable=g-long-lambda
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+ self.run_core_tests(
+ lambda: self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=seed,
+ reshuffle_each_iteration=reshuffle_each_iteration),
+ lambda: self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=10,
+ reshuffle_each_iteration=reshuffle_each_iteration),
+ num_outputs)
+ # pylint: enable=cell-var-from-loop
+ # pylint: enable=g-long-lambda
+
+ def testNonDeterministicSeeding(self):
+
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+
+ def ds_fn():
+ # pylint: disable=cell-var-from-loop
+ return self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=None, # Iterator seeds are generated non-deterministically.
+ reshuffle_each_iteration=reshuffle_each_iteration)
+ # pylint: enable=cell-var-from-loop
+
+      # We checkpoint the initial state of the Dataset so that we can restore
+      # the seeds in the next run. Since the seeding is non-deterministic, the
+      # dataset would otherwise be initialized with different seeds on each
+      # run.
+ expected = self.gen_outputs(
+ ds_fn,
+ break_points=[0],
+ num_outputs=num_outputs,
+ ckpt_saved=False,
+ verify_exhausted=False,
+ save_checkpoint_at_end=False)
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points=self.gen_break_points(num_outputs),
+ num_outputs=num_outputs,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.match(expected, actual)
+
+ def testMultipleIterators(self):
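+    # Saves and restores two one-shot iterators over the same dataset in a
+    # single checkpoint and verifies that both reproduce their sequences.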
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+
+ def ds_fn():
+ # pylint: disable=cell-var-from-loop
+ return self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=None, # Iterator seeds are generated non-deterministically.
+ reshuffle_each_iteration=reshuffle_each_iteration)
+ # pylint: enable=cell-var-from-loop
+
+ with ops.Graph().as_default() as g:
+ ds = ds_fn()
+ iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
+ get_next_ops = [it.get_next() for it in iterators]
+ saveables = [
+ contrib_iterator_ops.make_saveable_from_iterator(it)
+ for it in iterators
+ ]
+ for saveable in saveables:
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ saver = saver_lib.Saver(allow_empty=True)
+ with self.session(graph=g) as sess:
+ self._save(sess, saver)
+ expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ self._restore(saver, sess)
+ actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ self.match(expected, actual)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
new file mode 100644
index 0000000000..b179770ce3
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
@@ -0,0 +1,53 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SqlDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class SqlDatasetSerializationTest(
+ sql_dataset_op_test_base.SqlDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_repeats):
+ data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
+ driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
+ query = ("SELECT first_name, last_name, motto FROM students ORDER BY "
+ "first_name DESC")
+ output_types = (dtypes.string, dtypes.string, dtypes.string)
+ return readers.SqlDataset(driver_name, data_source_name, query,
+ output_types).repeat(num_repeats)
+
+ def testSQLSaveable(self):
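+    # num_outputs assumes the query matches two rows in the test database
+    # populated by SqlDatasetTestBase.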
+ num_repeats = 4
+ num_outputs = num_repeats * 2
+ self.run_core_tests(lambda: self._build_dataset(num_repeats),
+ lambda: self._build_dataset(num_repeats // 2),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
new file mode 100644
index 0000000000..ef7061b190
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -0,0 +1,106 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the StatsDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+# TODO(shivaniagrawal): Cannot checkpoint an input pipeline that uses the
+# `stats_ops.set_stats_aggregator` transformation, since serializing
+# StatsAggregator is not supported yet.
+class StatsDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset_bytes_stats(self, num_elements):
+ return dataset_ops.Dataset.range(num_elements).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
+ stats_ops.bytes_produced_stats("bytes_produced"))
+
+ def test_bytes_produced_stats_invalid_tag_shape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: dataset_ops.Dataset.range(100).apply(
+ stats_ops.bytes_produced_stats(["bytes_produced"])),
+ None, 100)
+ # pylint: enable=g-long-lambda
+
+ def testBytesStatsDatasetSaveableCore(self):
+ num_outputs = 100
+ self.run_core_tests(
+ lambda: self._build_dataset_bytes_stats(num_outputs),
+ lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
+
+ def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ stats_ops.latency_stats(tag))
+
+ def _build_dataset_multiple_tags(self,
+ num_elements,
+ tag1="record_latency",
+ tag2="record_latency_2"):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
+
+ def test_latency_stats_invalid_tag_shape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats(["record_latency", "record_latency_2"])),
+ None, 100)
+ # pylint: enable=g-long-lambda
+
+ def testLatencyStatsDatasetSaveableCore(self):
+ num_outputs = 100
+
+ self.run_core_tests(
+ lambda: self._build_dataset_latency_stats(num_outputs),
+ lambda: self._build_dataset_latency_stats(num_outputs // 10),
+ num_outputs)
+
+ self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
+ None, num_outputs)
+
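+    # Also cover the case where both latency transformations share a tag.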
+ tag1 = "record_latency"
+ tag2 = "record_latency"
+ self.run_core_tests(
+ lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
+ None, num_outputs)
+
+ def _build_dataset_stats_aggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ return dataset_ops.Dataset.range(10).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+
+ def test_set_stats_aggregator_not_support_checkpointing(self):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "does not support checkpointing"):
+ self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+
+
+if __name__ == "__main__":
+ test.main()
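Note: the TODO above is why `test_set_stats_aggregator_not_support_checkpointing` expects an `UnimplementedError`: `StatsAggregator` state cannot be serialized yet, while the tag-only transformations round-trip through save/restore. A minimal sketch of the kind of pipeline these tests can checkpoint (graph mode, using the same internal modules the test imports; the helper name is ours):

from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops


def build_checkpointable_stats_pipeline(num_elements=100):
  # latency_stats() and bytes_produced_stats() only tag elements with a
  # statistic name; they hold no aggregator state, so the resulting iterator
  # can be saved and restored by the serialization tests above.
  return dataset_ops.Dataset.range(num_elements).apply(
      stats_ops.latency_stats("record_latency")).apply(
          stats_ops.bytes_produced_stats("bytes_produced"))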
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
new file mode 100644
index 0000000000..c87a7443a7
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the TextLineDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class TextLineDatasetSerializationTest(
+ reader_dataset_ops_test_base.TextLineDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self, test_filenames, compression_type=None):
+ return core_readers.TextLineDataset(
+ test_filenames, compression_type=compression_type, buffer_size=10)
+
+ def testTextLineCore(self):
+ compression_types = [None, "GZIP", "ZLIB"]
+ num_files = 5
+ lines_per_file = 5
+ num_outputs = num_files * lines_per_file
+ for compression_type in compression_types:
+ test_filenames = self._createFiles(
+ num_files,
+ lines_per_file,
+ crlf=True,
+ compression_type=compression_type)
+ # pylint: disable=cell-var-from-loop
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(test_filenames, compression_type),
+ lambda: self._build_iterator_graph(test_filenames), num_outputs)
+ # pylint: enable=cell-var-from-loop
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
new file mode 100644
index 0000000000..f0dcc131d4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
@@ -0,0 +1,99 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the TFRecordDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class TFRecordDatasetSerializationTest(
+ reader_dataset_ops_test_base.TFRecordDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self,
+ num_epochs,
+ batch_size=1,
+ compression_type=None,
+ buffer_size=None):
+ filenames = self._createFiles()
+ if compression_type == "ZLIB":
+ zlib_files = []
+ for i, fn in enumerate(filenames):
+ with open(fn, "rb") as f:
+ cdata = zlib.compress(f.read())
+ zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+ zlib_files.append(zfn)
+ filenames = zlib_files
+
+ elif compression_type == "GZIP":
+ gzip_files = []
+      for i, fn in enumerate(filenames):
+ with open(fn, "rb") as f:
+ gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
+ with gzip.GzipFile(gzfn, "wb") as gzf:
+ gzf.write(f.read())
+ gzip_files.append(gzfn)
+ filenames = gzip_files
+
+ return core_readers.TFRecordDataset(
+ filenames, compression_type,
+ buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)
+
+ def testTFRecordWithoutBufferCore(self):
+ num_epochs = 5
+ batch_size = num_epochs
+ num_outputs = num_epochs * self._num_files * self._num_records // batch_size
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, batch_size,
+ buffer_size=0),
+ lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
+ num_outputs)
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
+ num_outputs * batch_size)
+ # pylint: enable=g-long-lambda
+
+ def testTFRecordWithBufferCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
+ lambda: self._build_iterator_graph(num_epochs * 2),
+ num_outputs)
+
+ def testTFRecordWithCompressionCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
+ lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
+ lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
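Note: the ZLIB and GZIP branches above rewrite the plain TFRecord files into compressed copies so each `compression_type` path of `TFRecordDataset` is exercised. A short reader-side sketch, assuming compressed files already exist on disk (the paths are hypothetical; the test writes its copies under `get_temp_dir()`):

from tensorflow.python.data.ops import readers as core_readers

# Hypothetical file names standing in for the test's generated copies.
gzip_files = ["/tmp/tfrecord_0.gz", "/tmp/tfrecord_1.gz"]
dataset = core_readers.TFRecordDataset(
    gzip_files, compression_type="GZIP", buffer_size=1024)
dataset = dataset.repeat(5).batch(1)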
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
new file mode 100644
index 0000000000..528598dfe4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the UnbatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class UnbatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
+ components = (
+ np.arange(tensor_slice_len),
+ np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(tensor_slice_len))
+
+ return dataset_ops.Dataset.from_tensor_slices(components).batch(
+ batch_size).apply(batching.unbatch())
+
+ def testCore(self):
+ tensor_slice_len = 8
+ batch_size = 2
+ num_outputs = tensor_slice_len
+ self.run_core_tests(
+ lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
+ lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
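Note: `batching.unbatch()` splits every batch back into its individual rows, which is why `num_outputs` above equals `tensor_slice_len` rather than the number of batches. A minimal sketch using the same modules the test imports:

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops

# Batch eight scalars into pairs, then undo the batching; iterating over
# `ds` yields the original scalars 0..7 one at a time.
ds = dataset_ops.Dataset.range(8).batch(2).apply(batching.unbatch())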
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
new file mode 100644
index 0000000000..e2862af4d6
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the UniqueDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class UniqueDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testUnique(self):
+
+ def build_dataset(num_elements, unique_elem_range):
+ return dataset_ops.Dataset.range(num_elements).map(
+ lambda x: x % unique_elem_range).apply(unique.unique())
+
+ self.run_core_tests(lambda: build_dataset(200, 100),
+ lambda: build_dataset(40, 100), 100)
+
+
+if __name__ == "__main__":
+ test.main()
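Note: `unique.unique()` drops any element that was already produced earlier in the stream, so a 200-element range mapped modulo 100 serializes to exactly 100 outputs. A small sketch:

from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.ops import dataset_ops

# 0, 1, 0, 1, ... collapses to the two distinct values 0 and 1.
ds = dataset_ops.Dataset.range(10).map(lambda x: x % 2).apply(unique.unique())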
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
new file mode 100644
index 0000000000..4ea6131c22
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ZipDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ZipDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, arr):
+ components = [
+ np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 22),
+ np.array(arr)
+ ]
+ datasets = [
+ dataset_ops.Dataset.from_tensor_slices(component)
+ for component in components
+ ]
+ return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
+
+ def testCore(self):
+ # Equal length components
+ arr = [37.0, 38.0, 39.0, 40.0]
+ num_outputs = len(arr)
+ self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs)
+ # Variable length components
+ diff_size_arr = [1.0, 2.0]
+ self.run_core_tests(lambda: self._build_dataset(diff_size_arr),
+ lambda: self._build_dataset(arr), 2)
+
+
+if __name__ == "__main__":
+ test.main()
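Note: `Dataset.zip()` preserves the nested structure of its argument and stops as soon as the shortest component is exhausted, which is why the variable-length case above yields only two elements. A brief sketch:

from tensorflow.python.data.ops import dataset_ops

a = dataset_ops.Dataset.range(4)
b = dataset_ops.Dataset.range(100, 104)
c = dataset_ops.Dataset.range(2)  # the shortest component limits the output

# Each element has the nested structure (a_elem, (b_elem, c_elem)), and the
# zipped dataset produces only two of them.
zipped = dataset_ops.Dataset.zip((a, (b, c)))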
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
new file mode 100644
index 0000000000..88d5c896c9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Integration test for dataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class SerializationIntegrationTest(test.TestCase):
+
+ def _build_input_pipeline(self, name, num_outputs):
+ with ops.name_scope(name):
+ ds = dataset_ops.Dataset.range(num_outputs).shuffle(
+ 10, reshuffle_each_iteration=False).prefetch(10)
+ iterator = ds.make_initializable_iterator()
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ return iterator.initializer, iterator.get_next()
+
+ def _build_graph(self, num_pipelines, num_outputs):
+ init_ops = []
+ get_next_ops = []
+ for i in range(num_pipelines):
+ name = "input_pipeline_%d" % i
+ init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
+ init_ops.append(init_op)
+ get_next_ops.append(get_next_op)
+ saver = saver_lib.Saver()
+ return init_ops, get_next_ops, saver
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def testConcurrentSaves(self):
+ num_pipelines = 100
+ num_outputs = 100
+ break_point = 10
+ all_outputs = [[] for _ in range(num_pipelines)]
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ sess.run(init_ops)
+ for _ in range(break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+ saver.save(sess, self._ckpt_path())
+
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ saver.restore(sess, self._ckpt_path())
+ for _ in range(num_outputs - break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+
+ for output in all_outputs:
+ self.assertSequenceEqual(sorted(output), range(num_outputs))
+
+
+if __name__ == "__main__":
+ test.main()
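Note: the integration test above works because `make_saveable_from_iterator()` wraps the iterator state in a saveable object; once that object is in the `SAVEABLE_OBJECTS` collection, an ordinary `Saver` checkpoints and restores the iterator position alongside everything else. A condensed single-pipeline sketch of the same pattern (graph mode, same internal modules as the test):

from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import saver as saver_lib

ds = dataset_ops.Dataset.range(100).shuffle(10, reshuffle_each_iteration=False)
iterator = ds.make_initializable_iterator()
# Registering the saveable is what lets Saver capture the iterator position.
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = saver_lib.Saver()  # saver.save()/saver.restore() now cover the iterator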
diff --git a/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py
new file mode 100644
index 0000000000..50895b5945
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py
@@ -0,0 +1,115 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+class ShuffleAndRepeatTest(test_base.DatasetTestBase):
+
+ def _build_ds(self, seed, count=5, num_elements=20):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed))
+
+ def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
+ get_next = ds_fn().make_one_shot_iterator().get_next()
+ outputs = []
+ with self.cached_session() as sess:
+ for _ in range(num_outputs):
+ outputs.append(sess.run(get_next))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ return outputs
+
+ def testCorrectOutput(self):
+ output = self._gen_outputs(lambda: self._build_ds(10), 100)
+ self.assertSequenceEqual(
+ sorted(output), sorted(
+ np.array([range(20) for _ in range(5)]).flatten()))
+ for i in range(5):
+ self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
+
+ def testReshuffling(self):
+ # Check that the output orders of different epochs are indeed different.
+ output = self._gen_outputs(lambda: self._build_ds(10), 100)
+ for i in range(4):
+ epoch1 = output[i * 20:(i + 1) * 20]
+ epoch2 = output[(i + 1) * 20:(i + 2) * 20]
+ self.assertNotEqual(epoch1, epoch2)
+
+ def testSameOrderForSameSeeds(self):
+ output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
+ output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
+ self.assertEqual(output1, output2)
+
+ def testDifferentOrderForDifferentSeeds(self):
+ output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
+ output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testCountNone(self):
+ output1 = self._gen_outputs(
+ lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
+ output2 = self._gen_outputs(
+ lambda: self._build_ds(20, count=None), 100, verify_exhausted=False)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testCountMinusOne(self):
+ output1 = self._gen_outputs(
+ lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
+ output2 = self._gen_outputs(
+ lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testInfiniteOutputs(self):
+    # Asserting that the iterator is exhausted after 100 items should fail.
+ with self.assertRaises(AssertionError):
+ self._gen_outputs(lambda: self._build_ds(10, count=None), 100)
+ with self.assertRaises(AssertionError):
+ self._gen_outputs(lambda: self._build_ds(10, count=-1), 100)
+
+ def testInfiniteEmpty(self):
+ with self.assertRaises(errors.OutOfRangeError):
+ self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
+ 100)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0),
+ 100)
+
+ def testLargeBufferSize(self):
+ with ops.Graph().as_default() as g:
+ ds = dataset_ops.Dataset.range(20).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=21))
+ get_next_op = ds.make_one_shot_iterator().get_next()
+ with self.session(graph=g) as sess:
+ sess.run(get_next_op)
+
+
+if __name__ == "__main__":
+ test.main()
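Note: `shuffle_and_repeat(buffer_size, count, seed)` fuses `shuffle()` and `repeat()` so shuffling is seeded once and spans epoch boundaries; `count=None` or `count=-1` repeats indefinitely, which is what `testInfiniteOutputs` relies on. A brief sketch:

from tensorflow.python.data.experimental.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops

# Five shuffled epochs of 0..19; a fixed seed makes the order reproducible.
ds = dataset_ops.Dataset.range(20).apply(
    shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=10))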
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py
new file mode 100644
index 0000000000..301f75488a
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py
@@ -0,0 +1,590 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental sql input op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
+
+ # Test that SqlDataset can read from a database table.
+ def testReadResultSet(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string), 2)
+ with self.cached_session() as sess:
+ for _ in range(2): # Run twice to verify statelessness of db operations.
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+        for _ in range(2):  # Dataset is repeated; see num_repeats=2 above.
+ self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that SqlDataset works on a join query.
+ def testReadResultSetJoinQuery(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT students.first_name, state, motto FROM students "
+ "INNER JOIN people "
+ "ON students.first_name = people.first_name "
+ "AND students.last_name = people.last_name"
+ })
+ self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that SqlDataset can read a database entry with a null-terminator
+ # in the middle of the text and place the entry in a `string` tensor.
+ def testReadResultSetNullTerminator(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, favorite_nonsense_word "
+ "FROM students ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that SqlDataset works when used on two different queries.
+ # Because the output types of the dataset must be determined at graph-creation
+ # time, the two queries must have the same number and types of columns.
+ def testReadResultSetReuseSqlDataset(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, state FROM people "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
+ self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+  # Test that an `OutOfRangeError` is raised on the first call to `get_next`
+  # if the result set is empty.
+ def testReadEmptyResultSet(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "WHERE first_name = 'Nonexistent'"
+ })
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that an error is raised when `driver_name` is invalid.
+ def testReadResultSetWithInvalidDriverName(self):
+ init_op = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))[0]
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(
+ init_op,
+ feed_dict={
+ self.driver_name: "sqlfake",
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+
+  # Test that an error is raised when a column name in `query` is nonexistent.
+ def testReadResultSetWithInvalidColumnName(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, fake_column FROM students "
+ "ORDER BY first_name DESC"
+ })
+ with self.assertRaises(errors.UnknownError):
+ sess.run(get_next)
+
+ # Test that an error is raised when there is a syntax error in `query`.
+ def testReadResultSetOfQueryWithSyntaxError(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELEmispellECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+ with self.assertRaises(errors.UnknownError):
+ sess.run(get_next)
+
+ # Test that an error is raised when the number of columns in `query`
+ # does not match the length of `output_types`.
+ def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name FROM students "
+ "ORDER BY first_name DESC"
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+  # Test that an error is raised when `query` is an insert query rather than
+  # a select query. In particular, the error refers to the number of output
+  # types passed to the op not matching the number of columns in the result
+  # set of the query (namely, 0 for an insert statement).
+ def testReadResultSetOfInsertQuery(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "INSERT INTO students (first_name, last_name, motto) "
+ "VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table and
+ # place it in an `int8` tensor.
+ def testReadResultSetInt8(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a negative or 0-valued integer from a
+ # SQLite database table and place it in an `int8` tensor.
+ def testReadResultSetInt8NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
+ dtypes.int8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income, favorite_negative_number "
+ "FROM students "
+ "WHERE first_name = 'John' ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0, -2), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a large (positive or negative) integer from
+ # a SQLite database table and place it in an `int8` tensor.
+ def testReadResultSetInt8MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT desk_number, favorite_negative_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((9, -2), sess.run(get_next))
+ # Max and min values of int8
+ self.assertEqual((127, -128), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table and
+ # place it in an `int16` tensor.
+ def testReadResultSetInt16(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a negative or 0-valued integer from a
+ # SQLite database table and place it in an `int16` tensor.
+ def testReadResultSetInt16NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
+ dtypes.int16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income, favorite_negative_number "
+ "FROM students "
+ "WHERE first_name = 'John' ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0, -2), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a large (positive or negative) integer from
+ # a SQLite database table and place it in an `int16` tensor.
+ def testReadResultSetInt16MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, favorite_medium_sized_number "
+ "FROM students ORDER BY first_name DESC"
+ })
+ # Max value of int16
+ self.assertEqual((b"John", 32767), sess.run(get_next))
+ # Min value of int16
+ self.assertEqual((b"Jane", -32768), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table and
+ # place it in an `int32` tensor.
+ def testReadResultSetInt32(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+      self.assertEqual((b"Jane", 127), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+ # Test that `SqlDataset` can read a negative or 0-valued integer from a
+ # SQLite database table and place it in an `int32` tensor.
+ def testReadResultSetInt32NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ self.assertEqual((b"Jane", -20000), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a large (positive or negative) integer from
+ # a SQLite database table and place it in an `int32` tensor.
+ def testReadResultSetInt32MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, favorite_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ # Max value of int32
+ self.assertEqual((b"John", 2147483647), sess.run(get_next))
+ # Min value of int32
+ self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
+ # table and place it in an `int32` tensor.
+ def testReadResultSetInt32VarCharColumnAsInt(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, school_id FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 123), sess.run(get_next))
+ self.assertEqual((b"Jane", 1000), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table
+ # and place it in an `int64` tensor.
+ def testReadResultSetInt64(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a negative or 0-valued integer from a
+ # SQLite database table and place it in an `int64` tensor.
+ def testReadResultSetInt64NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ self.assertEqual((b"Jane", -20000), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a large (positive or negative) integer from
+ # a SQLite database table and place it in an `int64` tensor.
+ def testReadResultSetInt64MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, favorite_big_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ # Max value of int64
+ self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
+ # Min value of int64
+ self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table and
+ # place it in a `uint8` tensor.
+ def testReadResultSetUInt8(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read the minimum and maximum uint8 values from a
+ # SQLite database table and place them in `uint8` tensors.
+ def testReadResultSetUInt8MinAndMaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, brownie_points FROM students "
+ "ORDER BY first_name DESC"
+ })
+ # Min value of uint8
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ # Max value of uint8
+ self.assertEqual((b"Jane", 255), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table
+ # and place it in a `uint16` tensor.
+ def testReadResultSetUInt16(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read the minimum and maximum uint16 values from a
+ # SQLite database table and place them in `uint16` tensors.
+ def testReadResultSetUInt16MinAndMaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, account_balance FROM students "
+ "ORDER BY first_name DESC"
+ })
+ # Min value of uint16
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ # Max value of uint16
+ self.assertEqual((b"Jane", 65535), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+  # Test that `SqlDataset` can read 0-valued and 1-valued integers from a
+  # SQLite database table and place them as `False` and `True` respectively
+  # in `bool` tensors.
+ def testReadResultSetBool(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, registration_complete FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", True), sess.run(get_next))
+ self.assertEqual((b"Jane", False), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
+ # from a SQLite database table and place it as `True` in a `bool` tensor.
+ def testReadResultSetBoolNotZeroOrOne(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, favorite_medium_sized_number "
+ "FROM students ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", True), sess.run(get_next))
+ self.assertEqual((b"Jane", True), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a float from a SQLite database table
+ # and place it in a `float64` tensor.
+ def testReadResultSetFloat64(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.float64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, victories FROM townspeople "
+ "ORDER BY first_name"
+ })
+ self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
+ self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+  # Test that `SqlDataset` can read a float whose precision exceeds that of a
+  # 64-bit IEEE value from a SQLite database table without raising an error,
+  # and that the returned value compares equal to the same literal in Python.
+ def testReadResultSetFloat64OverlyPrecise(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.float64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, accolades FROM townspeople "
+ "ORDER BY first_name"
+ })
+ self.assertEqual(
+ (b"George", b"Washington",
+ 1331241.321342132321324589798264627463827647382647382643874),
+ sess.run(get_next))
+ self.assertEqual(
+ (b"John", b"Adams",
+ 1331241321342132321324589798264627463827647382647382643874.0),
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a float from a SQLite database table,
+ # representing the largest integer representable as a 64-bit IEEE float
+ # such that the previous integer is also representable as a 64-bit IEEE float.
+ # Test that `SqlDataset` can distinguish these two numbers.
+ def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.float64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, triumphs FROM townspeople "
+ "ORDER BY first_name"
+ })
+ self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
+ sess.run(get_next))
+ self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
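Note: every case above goes through the `_createSqlDataset()` helper defined in the base class that follows; stripped down, the pattern is a `SqlDataset` built from a driver name, a data-source path, a query, and one dtype per selected column. A sketch against a hypothetical SQLite file:

from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.framework import dtypes

# Hypothetical database path and query; output_types must match the columns.
dataset = readers.SqlDataset(
    driver_name="sqlite",
    data_source_name="/tmp/tftest.sqlite",
    query="SELECT first_name, desk_number FROM students",
    output_types=(dtypes.string, dtypes.int32))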
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py
new file mode 100644
index 0000000000..a135c357f0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py
@@ -0,0 +1,95 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing SqlDataset."""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import sqlite3
+
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class SqlDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing SqlDataset."""
+
+ def _createSqlDataset(self, output_types, num_repeats=1):
+ dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
+ self.query, output_types).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return init_op, get_next
+
+ def setUp(self):
+ self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
+ self.driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
+ self.query = array_ops.placeholder(dtypes.string, shape=[])
+
+ conn = sqlite3.connect(self.data_source_name)
+ c = conn.cursor()
+ c.execute("DROP TABLE IF EXISTS students")
+ c.execute("DROP TABLE IF EXISTS people")
+ c.execute("DROP TABLE IF EXISTS townspeople")
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
+ "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
+ "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
+ "desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
+ "favorite_big_number INTEGER, favorite_negative_number INTEGER, "
+ "favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
+ "account_balance INTEGER, registration_complete INTEGER)")
+ c.executemany(
+ "INSERT INTO students (first_name, last_name, motto, school_id, "
+ "favorite_nonsense_word, desk_number, income, favorite_number, "
+ "favorite_big_number, favorite_negative_number, "
+ "favorite_medium_sized_number, brownie_points, account_balance, "
+ "registration_complete) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
+ 9223372036854775807, -2, 32767, 0, 0, 1),
+ ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
+ -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
+ "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
+ c.executemany(
+ "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
+ [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
+ "California")])
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
+ "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
+ "FLOAT, accolades FLOAT, triumphs FLOAT)")
+ c.executemany(
+ "INSERT INTO townspeople (first_name, last_name, victories, "
+ "accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
+ [("George", "Washington", 20.00,
+ 1331241.321342132321324589798264627463827647382647382643874,
+ 9007199254740991.0),
+ ("John", "Adams", -19.95,
+ 1331241321342132321324589798264627463827647382647382643874.0,
+ 9007199254740992.0)])
+ conn.commit()
+ conn.close()
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
new file mode 100644
index 0000000000..6761fbd16b
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -0,0 +1,253 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline statistics gathering ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
+from tensorflow.python.data.experimental.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+
+ def testBytesProduced(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
+ stats_ops.bytes_produced_stats("bytes_produced")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ expected_sum = 0.0
+ for i in range(100):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
+ expected_sum += i * 8.0
+ self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
+ self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
+
+ def testLatencyStats(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
+
+ def testPrefetchBufferUtilization(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ float(i + 1))
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
+ self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
+ 0, 1)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ 100)
+
+ def testPrefetchBufferScalars(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ 0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasScalarValue(summary_str,
+ "Prefetch::buffer_capacity", 0)
+ self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
+ 0)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testFilteredElementsStats(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(101).filter(
+ lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+    with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(34):
+ self.assertEqual(i * 3, sess.run(next_element))
+        if i != 0:
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", 67.0)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", 34.0)
+
+ def testReinitialize(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ for j in range(5):
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", (j + 1) * 100.0)
+
+ def testNoAggregatorRegistered(self):
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency"))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMultipleTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.latency_stats("record_latency_2")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float(i + 1))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency_2", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency_2", 100.0)
+
+ def testRepeatedTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float(2 * (i + 1)))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
+
+ def testMultipleIteratorsSameAggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator_0 = dataset.make_initializable_iterator()
+ iterator_1 = dataset.make_initializable_iterator()
+ next_element = iterator_0.get_next() + iterator_1.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run([iterator_0.initializer, iterator_1.initializer])
+ for i in range(100):
+ self.assertEqual(i * 2, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float(2 * (i + 1)))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
new file mode 100644
index 0000000000..80f2625927
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing the input pipeline statistics gathering ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.core.framework import summary_pb2
+from tensorflow.python.data.kernel_tests import test_base
+
+
+class StatsDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for testing statistics gathered in `StatsAggregator`."""
+
+ def _assertSummaryContains(self, summary_str, tag):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasCount(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.histo.num)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertLessEqual(min_value, value.histo.min)
+ self.assertGreaterEqual(max_value, value.histo.max)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasSum(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.histo.sum)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.simple_value)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
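For reference, the helpers above parse serialized `Summary` protos of roughly the following shape; this is an illustrative sketch only, and the tags and values are not produced by any particular pipeline.

```python
from tensorflow.core.framework import summary_pb2

summary_proto = summary_pb2.Summary()

# A histogram-valued stat (what `_assertSummaryHasCount` inspects via
# `value.histo.num`, and `_assertSummaryHasRange` via min/max).
histo_value = summary_proto.value.add()
histo_value.tag = "record_latency"
histo_value.histo.num = 100.0
histo_value.histo.min = 0.5
histo_value.histo.max = 3.0
histo_value.histo.sum = 150.0

# A scalar-valued stat (what `_assertSummaryHasScalarValue` inspects
# via `value.simple_value`).
scalar_value = summary_proto.value.add()
scalar_value.tag = "Prefetch::buffer_capacity"
scalar_value.simple_value = 0.0

# The helpers take the serialized form, as returned by evaluating
# `StatsAggregator.get_summary()` in a session.
summary_str = summary_proto.SerializeToString()
```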
diff --git a/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py
new file mode 100644
index 0000000000..4432dcb05a
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py
@@ -0,0 +1,91 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline statistics gathering ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import threadpool
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", 4, -1),
+ ("7", 4, 0),
+ ("8", 4, 1),
+ ("9", 4, 4),
+ )
+ def testNumThreads(self, num_threads, max_intra_op_parallelism):
+
+ def get_thread_id(_):
+ # Python creates a dummy thread object to represent the current
+ # thread when called from an "alien" thread (such as a
+ # `PrivateThreadPool` thread in this case). It does not include
+ # the TensorFlow-given display name, but it has a unique
+ # identifier that maps one-to-one with the underlying OS thread.
+ return np.array(threading.current_thread().ident).astype(np.int64)
+
+ dataset = (
+ dataset_ops.Dataset.range(1000).map(
+ lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
+ num_parallel_calls=32).apply(unique.unique()))
+
+ dataset = threadpool.override_threadpool(
+ dataset,
+ threadpool.PrivateThreadPool(
+ num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name="private_thread_pool_%d" % num_threads))
+
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ thread_ids = []
+ try:
+ while True:
+ thread_ids.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ self.assertEqual(len(thread_ids), len(set(thread_ids)))
+ self.assertGreater(len(thread_ids), 0)
+ # NOTE(mrry): We don't control the thread pool scheduling, and
+ # so cannot guarantee that all of the threads in the pool will
+ # perform work.
+ self.assertLessEqual(len(thread_ids), num_threads)
+
+
+if __name__ == "__main__":
+ test.main()
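Outside the test harness, the same override applies to an ordinary pipeline roughly as sketched below; the dataset contents, thread count, and display name are arbitrary examples.

```python
from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(100).map(
    lambda x: x * 2, num_parallel_calls=4)

# Route this pipeline's work onto a private thread pool instead of the
# session's shared inter-op pool.
dataset = threadpool.override_threadpool(
    dataset,
    threadpool.PrivateThreadPool(
        4,
        max_intra_op_parallelism=1,
        display_name="example_private_pool"))

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
```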
diff --git a/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
new file mode 100644
index 0000000000..b5a0b20f3f
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class UniqueDatasetTest(test_base.DatasetTestBase):
+
+ def _testSimpleHelper(self, dtype, test_cases):
+ """Test the `unique()` transformation on a list of test cases.
+
+ Args:
+ dtype: The `dtype` of the elements in each test case.
+ test_cases: A list of pairs of lists. The first component is the test
+ input that will be passed to the transformation; the second component
+ is the expected sequence of outputs from the transformation.
+ """
+
+ # The `current_test_case` will be updated when we loop over `test_cases`
+ # below; declare it here so that the generator can capture it once.
+ current_test_case = []
+ dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
+ dtype).apply(unique.unique())
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for test_case, expected in test_cases:
+ current_test_case = test_case
+ sess.run(iterator.initializer)
+ for element in expected:
+ if dtype == dtypes.string:
+ element = compat.as_bytes(element)
+ self.assertAllEqual(element, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testSimpleInt(self):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ self._testSimpleHelper(dtype, [
+ ([], []),
+ ([1], [1]),
+ ([1, 1, 1, 1, 1, 1, 1], [1]),
+ ([1, 2, 3, 4], [1, 2, 3, 4]),
+ ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
+ ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
+ ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
+ ])
+
+ def testSimpleString(self):
+ self._testSimpleHelper(dtypes.string, [
+ ([], []),
+ (["hello"], ["hello"]),
+ (["hello", "hello", "hello"], ["hello"]),
+ (["hello", "world"], ["hello", "world"]),
+ (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
+ ])
+
+
+if __name__ == "__main__":
+ test.main()
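A more direct usage sketch of the transformation exercised above; the element values are arbitrary.

```python
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.ops import dataset_ops

# Keeps only the first occurrence of each element, preserving order.
dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 1, 3, 3, 2, 4])
dataset = dataset.apply(unique.unique())
# ==> 1, 2, 3, 4
```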
diff --git a/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py
new file mode 100644
index 0000000000..25a2e63ba1
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py
@@ -0,0 +1,118 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.ops import writers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class TFRecordWriterTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ super(TFRecordWriterTest, self).setUp()
+ self._num_records = 7
+ self.filename = array_ops.placeholder(dtypes.string, shape=[])
+ self.compression_type = array_ops.placeholder_with_default("", shape=[])
+
+ input_dataset = readers.TFRecordDataset([self.filename],
+ self.compression_type)
+ self.writer = writers.TFRecordWriter(
+ self._outputFilename(), self.compression_type).write(input_dataset)
+
+ def _record(self, i):
+ return compat.as_bytes("Record %d" % (i))
+
+ def _createFile(self, options=None):
+ filename = self._inputFilename()
+ writer = python_io.TFRecordWriter(filename, options)
+ for i in range(self._num_records):
+ writer.write(self._record(i))
+ writer.close()
+ return filename
+
+ def _inputFilename(self):
+ return os.path.join(self.get_temp_dir(), "tf_record.in.txt")
+
+ def _outputFilename(self):
+ return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
+
+ def testWrite(self):
+ with self.cached_session() as sess:
+ sess.run(
+ self.writer, feed_dict={
+ self.filename: self._createFile(),
+ })
+      for i, r in enumerate(
+          tf_record.tf_record_iterator(self._outputFilename())):
+ self.assertAllEqual(self._record(i), r)
+
+ def testWriteZLIB(self):
+ options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
+ with self.cached_session() as sess:
+ sess.run(
+ self.writer,
+ feed_dict={
+ self.filename: self._createFile(options),
+ self.compression_type: "ZLIB",
+ })
+ for i, r in enumerate(
+ tf_record.tf_record_iterator(self._outputFilename(), options=options)):
+ self.assertAllEqual(self._record(i), r)
+
+ def testWriteGZIP(self):
+ options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
+ with self.cached_session() as sess:
+ sess.run(
+ self.writer,
+ feed_dict={
+ self.filename: self._createFile(options),
+ self.compression_type: "GZIP",
+ })
+ for i, r in enumerate(
+ tf_record.tf_record_iterator(self._outputFilename(), options=options)):
+ self.assertAllEqual(self._record(i), r)
+
+ def testFailDataset(self):
+ with self.assertRaises(TypeError):
+ writers.TFRecordWriter(self._outputFilename(),
+ self.compression_type).write("whoops")
+
+ def testFailDType(self):
+ input_dataset = dataset_ops.Dataset.from_tensors(10)
+ with self.assertRaises(TypeError):
+ writers.TFRecordWriter(self._outputFilename(),
+ self.compression_type).write(input_dataset)
+
+ def testFailShape(self):
+ input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
+ with self.assertRaises(TypeError):
+ writers.TFRecordWriter(self._outputFilename(),
+ self.compression_type).write(input_dataset)
+
+
+if __name__ == "__main__":
+ test.main()
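A standalone usage sketch of the writer exercised above; the output path and record contents are placeholders.

```python
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import writers
from tensorflow.python.data.ops import dataset_ops

# A dataset of scalar strings, one per record to be written.
dataset = dataset_ops.Dataset.from_tensor_slices(
    [b"record 0", b"record 1", b"record 2"])

# `write()` returns an op that consumes the dataset when it is run.
write_op = writers.TFRecordWriter("/tmp/example.tfrecord").write(dataset)

with session.Session() as sess:
  sess.run(write_op)
```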
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
new file mode 100644
index 0000000000..915d399f1b
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -0,0 +1,377 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+)
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+
+py_library(
+ name = "counter",
+ srcs = ["counter.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":scan_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "get_single_element",
+ srcs = ["get_single_element.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "iterator_ops",
+ srcs = [
+ "iterator_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:basic_session_run_hooks",
+ "//tensorflow/python:checkpoint_management",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:saver",
+ "//tensorflow/python:session_run_hook",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ ],
+)
+
+py_library(
+ name = "random_ops",
+ srcs = [
+ "random_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "readers",
+ srcs = [
+ "readers.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":batching",
+ ":interleave_ops",
+ ":optimization",
+ ":parsing_ops",
+ ":shuffle_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:convert",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "shuffle_ops",
+ srcs = [
+ "shuffle_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "batching",
+ srcs = ["batching.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":get_single_element",
+ ":grouping",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:convert",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "enumerate_ops",
+ srcs = ["enumerate_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "error_ops",
+ srcs = ["error_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "grouping",
+ srcs = ["grouping.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "interleave_ops",
+ srcs = ["interleave_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":random_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:stateless_random_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "optimization",
+ srcs = ["optimization.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "parsing_ops",
+ srcs = ["parsing_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
+ name = "map_defun",
+ srcs = ["map_defun.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
+py_library(
+ name = "resampling",
+ srcs = ["resampling.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":batching",
+ ":interleave_ops",
+ ":scan_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:logging_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "scan_ops",
+ srcs = ["scan_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "stats_ops",
+ srcs = ["stats_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "threadpool",
+ srcs = ["threadpool.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_library(
+ name = "unique",
+ srcs = [
+ "unique.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "writers",
+ srcs = [
+ "writers.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "indexed_dataset_ops",
+ srcs = ["indexed_dataset_ops.py"],
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "prefetching_ops",
+ srcs = ["prefetching_ops.py"],
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "dataset_ops",
+ deps = [
+ ":batching",
+ ":counter",
+ ":enumerate_ops",
+ ":error_ops",
+ ":get_single_element",
+ ":grouping",
+ ":indexed_dataset_ops",
+ ":interleave_ops",
+ ":map_defun",
+ ":optimization",
+ ":prefetching_ops",
+ ":readers",
+ ":resampling",
+ ":scan_ops",
+ ":shuffle_ops",
+ ":stats_ops",
+ ":threadpool",
+ ":unique",
+ ":writers",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py
new file mode 100644
index 0000000000..d42af9e7e9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/batching.py
@@ -0,0 +1,669 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Batching dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import get_single_element
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def batch_window(dataset):
+ """Batches a window of tensors.
+
+ Args:
+ dataset: the input dataset.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+ """
+ if isinstance(dataset.output_classes, tuple):
+ raise TypeError("Input dataset expected to have a single component")
+ if dataset.output_classes is ops.Tensor:
+ return _batch_dense_window(dataset)
+ elif dataset.output_classes is sparse_tensor.SparseTensor:
+ return _batch_sparse_window(dataset)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _batch_dense_window(dataset):
+ """Batches a window of dense tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return array_ops.shape(first_element)
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, array_ops.shape(value))
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat([[0], shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _batch_sparse_window(dataset):
+ """Batches a window of sparse tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return first_element.dense_shape
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, value.dense_shape)
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64),
+ math_ops.cast(shape, dtypes.int64)], 0))
+
+ def batch_reduce_fn(state, value):
+ return sparse_ops.sparse_concat(0, [state, value])
+
+ def reshape_fn(value):
+ return sparse_ops.sparse_reshape(
+ value,
+ array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(reshape_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+@tf_export("data.experimental.dense_to_sparse_batch")
+def dense_to_sparse_batch(batch_size, row_shape):
+ """A transformation that batches ragged elements into `tf.SparseTensor`s.
+
+ Like `Dataset.padded_batch()`, this transformation combines multiple
+ consecutive elements of the dataset, which might have different
+ shapes, into a single element. The resulting element has three
+ components (`indices`, `values`, and `dense_shape`), which
+ comprise a `tf.SparseTensor` that represents the same data. The
+ `row_shape` represents the dense shape of each row in the
+ resulting `tf.SparseTensor`, to which the effective batch size is
+ prepended. For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.data.experimental.dense_to_sparse_batch(
+ batch_size=2, row_shape=[6])) ==
+ {
+ ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices
+ ['a', 'b', 'c', 'a', 'b'], # values
+ [2, 6]), # dense_shape
+ ([[0, 0], [0, 1], [0, 2], [0, 3]],
+ ['a', 'b', 'c', 'd'],
+ [1, 6])
+ }
+ ```
+
+ Args:
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the
+ number of consecutive elements of this dataset to combine in a
+ single batch.
+ row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the equivalent dense shape of a row in the
+ resulting `tf.SparseTensor`. Each element of this dataset must
+ have the same rank as `row_shape`, and must have size less
+ than or equal to `row_shape` in each dimension.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+
+ return _apply_fn
+
+
+def padded_batch_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of tensors with padding.
+
+ Args:
+ dataset: the input dataset.
+    padded_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the shape to which the input elements should be padded
+ prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
+ `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
+ maximum size of that dimension in each batch.
+ padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
+ padding value to use. Defaults are `0` for numeric types and the empty
+ string for string types. If `dataset` contains `tf.SparseTensor`, this
+ value is ignored.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+
+ Raises:
+ ValueError: if invalid arguments are provided.
+ """
+ if not issubclass(dataset.output_classes,
+ (ops.Tensor, sparse_tensor.SparseTensor)):
+ raise TypeError("Input dataset expected to have a single tensor component")
+  if issubclass(dataset.output_classes, ops.Tensor):
+    return _padded_batch_dense_window(dataset, padded_shape, padding_value)
+  elif issubclass(dataset.output_classes, sparse_tensor.SparseTensor):
+ if padding_value is not None:
+ raise ValueError("Padding value not allowed for sparse tensors")
+ return _padded_batch_sparse_window(dataset, padded_shape)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of dense tensors with padding."""
+
+ padded_shape = math_ops.cast(
+ convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return padded_shape
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(array_ops.shape(value), padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ",
+ array_ops.shape(value), padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, array_ops.shape(value))
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ if padding_value is None:
+ if dataset.output_types == dtypes.string:
+ padding_value = ""
+ elif dataset.output_types == dtypes.bool:
+ padding_value = False
+ elif dataset.output_types == dtypes.variant:
+ raise TypeError("Unable to create padding for field of type 'variant'")
+ else:
+ padding_value = 0
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat(
+ [np.array([0], dtype=np.int32), padded_shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ def pad_fn(value):
+ shape = array_ops.shape(value)
+ left = array_ops.zeros_like(shape)
+ right = padded_shape - shape
+ return array_ops.pad(
+ value, array_ops.stack([left, right], 1), constant_values=padding_value)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(pad_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _padded_batch_sparse_window(dataset, padded_shape):
+ """Batches a window of sparse tensors with padding."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return convert.partial_shape_to_tensor(padded_shape)
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(value.dense_shape, padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ", value.dense_shape,
+ padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, value.dense_shape)
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
+ 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64), padded_shape], 0))
+
+ def batch_reduce_fn(state, value):
+ padded_value = sparse_tensor.SparseTensor(
+ indices=value.indices, values=value.values, dense_shape=padded_shape)
+ reshaped_value = sparse_ops.sparse_reshape(
+ padded_value,
+ array_ops.concat(
+ [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
+ return sparse_ops.sparse_concat(0, [state, reshaped_value])
+
+ reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
+
+
+class _UnbatchDataset(dataset_ops.UnaryDataset):
+ """A dataset that splits the elements of its input into multiple elements."""
+
+ def __init__(self, input_dataset):
+ """See `unbatch()` for more details."""
+ super(_UnbatchDataset, self).__init__(input_dataset)
+ flat_shapes = nest.flatten(input_dataset.output_shapes)
+ if any(s.ndims == 0 for s in flat_shapes):
+ raise ValueError("Cannot unbatch an input with scalar components.")
+ known_batch_dim = tensor_shape.Dimension(None)
+ for s in flat_shapes:
+ try:
+ known_batch_dim = known_batch_dim.merge_with(s[0])
+ except ValueError:
+ raise ValueError("Cannot unbatch an input whose components have "
+ "different batch sizes.")
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.unbatch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda s: s[1:],
+ self._input_dataset.output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+@tf_export("data.experimental.unbatch")
+def unbatch():
+ """Splits elements of a dataset into multiple elements on the batch dimension.
+
+ For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
+ where `B` may vary for each input element, then for each element in the
+ dataset, the unbatched dataset will contain `B` consecutive elements
+ of shape `[a0, a1, ...]`.
+
+ ```python
+ # NOTE: The following example uses `{ ... }` to represent the contents
+ # of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.data.experimental.unbatch()) == {
+ 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ if not sparse.any_sparse(dataset.output_classes):
+ return _UnbatchDataset(dataset)
+
+ # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+ # are normalized to the rank-1 dense representation, so that the
+ # sparse-oblivious unbatching logic will slice them
+ # appropriately. This leads to a somewhat inefficient re-encoding step
+ # for all SparseTensor components.
+ # TODO(mrry): Consider optimizing this in future
+ # if it turns out to be a bottleneck.
+ def normalize(arg, *rest):
+ if rest:
+ return sparse.serialize_many_sparse_tensors((arg,) + rest)
+ else:
+ return sparse.serialize_many_sparse_tensors(arg)
+
+ normalized_dataset = dataset.map(normalize)
+
+ # NOTE(mrry): Our `map()` has lost information about the sparseness
+ # of any SparseTensor components, so re-apply the structure of the
+ # original dataset.
+ restructured_dataset = _RestructuredDataset(
+ normalized_dataset,
+ dataset.output_types,
+ dataset.output_shapes,
+ dataset.output_classes,
+ allow_unsafe_cast=True)
+ return _UnbatchDataset(restructured_dataset)
+
+ return _apply_fn
+
+
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
+
+ def __init__(self, input_dataset, batch_size, row_shape):
+ """See `Dataset.dense_to_sparse_batch()` for more details."""
+ super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
+ if not isinstance(input_dataset.output_types, dtypes.DType):
+ raise TypeError("DenseToSparseDataset requires an input whose elements "
+ "have a single component, whereas the input has %r." %
+ input_dataset.output_types)
+ self._input_dataset = input_dataset
+ self._batch_size = batch_size
+ self._row_shape = row_shape
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.dense_to_sparse_batch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._batch_size,
+ row_shape=convert.partial_shape_to_tensor(self._row_shape),
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return sparse_tensor.SparseTensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.vector(None).concatenate(self._row_shape)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _RestructuredDataset(dataset_ops.UnaryDataset):
+ """An internal helper for changing the structure and shape of a dataset."""
+
+ def __init__(self,
+ dataset,
+ output_types,
+ output_shapes=None,
+ output_classes=None,
+ allow_unsafe_cast=False):
+ """Creates a new dataset with the given output types and shapes.
+
+ The given `dataset` must have a structure that is convertible:
+    * `dataset.output_types` must be the same as `output_types` modulo nesting.
+ * Each shape in `dataset.output_shapes` must be compatible with each shape
+ in `output_shapes` (if given).
+
+ Note: This helper permits "unsafe casts" for shapes, equivalent to using
+ `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
+ Args:
+ dataset: A `Dataset` object.
+ output_types: A nested structure of `tf.DType` objects.
+ output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+ If omitted, the shapes will be inherited from `dataset`.
+ output_classes: (Optional.) A nested structure of class types.
+ If omitted, the class types will be inherited from `dataset`.
+ allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+ reported output types and shapes of the restructured dataset, e.g. to
+ switch a sparse tensor represented as `tf.variant` to its user-visible
+ type and shape.
+
+ Raises:
+ ValueError: If either `output_types` or `output_shapes` is not compatible
+ with the structure of `dataset`.
+ """
+ super(_RestructuredDataset, self).__init__(dataset)
+ self._input_dataset = dataset
+
+ if not allow_unsafe_cast:
+ # Validate that the types are compatible.
+ output_types = nest.map_structure(dtypes.as_dtype, output_types)
+ flat_original_types = nest.flatten(dataset.output_types)
+ flat_new_types = nest.flatten(output_types)
+ if flat_original_types != flat_new_types:
+ raise ValueError(
+ "Dataset with output types %r cannot be restructured to have "
+ "output types %r" % (dataset.output_types, output_types))
+
+ self._output_types = output_types
+
+ if output_shapes is None:
+ # Inherit shapes from the original `dataset`.
+ self._output_shapes = nest.pack_sequence_as(output_types,
+ nest.flatten(
+ dataset.output_shapes))
+ else:
+ if not allow_unsafe_cast:
+ # Validate that the shapes are compatible.
+ nest.assert_same_structure(output_types, output_shapes)
+ flat_original_shapes = nest.flatten(dataset.output_shapes)
+ flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes):
+ if not original_shape.is_compatible_with(new_shape):
+ raise ValueError(
+ "Dataset with output shapes %r cannot be restructured to have "
+ "incompatible output shapes %r" % (dataset.output_shapes,
+ output_shapes))
+ self._output_shapes = nest.map_structure_up_to(
+ output_types, tensor_shape.as_shape, output_shapes)
+ if output_classes is None:
+ # Inherit class types from the original `dataset`.
+ self._output_classes = nest.pack_sequence_as(output_types,
+ nest.flatten(
+ dataset.output_classes))
+ else:
+ self._output_classes = output_classes
+
+ def _as_variant_tensor(self):
+ return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+
+class _MapAndBatchDataset(dataset_ops.MapDataset):
+ """A `Dataset` that maps a function over a batch of elements."""
+
+ def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
+ drop_remainder):
+ """See `Dataset.map()` for details."""
+ super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
+ self._batch_size_t = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+ self._num_parallel_calls_t = ops.convert_to_tensor(
+ num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+ self._drop_remainder_t = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+
+ self._batch_size = batch_size
+ self._drop_remainder = drop_remainder
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ input_resource = self._input_dataset._as_variant_tensor()
+ return gen_dataset_ops.map_and_batch_dataset_v2(
+ input_resource,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ batch_size=self._batch_size_t,
+ num_parallel_calls=self._num_parallel_calls_t,
+ drop_remainder=self._drop_remainder_t,
+ **dataset_ops.flat_structure(self))
+ # pylint: enable=protected-access
+
+ @property
+ def output_shapes(self):
+ dim = self._batch_size if self._drop_remainder else None
+ return nest.pack_sequence_as(self._output_shapes, [
+ tensor_shape.vector(dim).concatenate(s)
+ for s in nest.flatten(self._output_shapes)
+ ])
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+@tf_export("data.experimental.map_and_batch")
+def map_and_batch(map_func,
+ batch_size,
+ num_parallel_batches=None,
+ drop_remainder=False,
+ num_parallel_calls=None):
+ """Fused implementation of `map` and `batch`.
+
+ Maps `map_func` across `batch_size` consecutive elements of this dataset
+ and then combines them into a batch. Functionally, it is equivalent to `map`
+ followed by `batch`. However, by fusing the two transformations together, the
+ implementation can be more efficient. Surfacing this transformation in the API
+ is temporary. Once automatic input pipeline optimization is implemented,
+ the fusing of `map` and `batch` will happen automatically and this API will be
+ deprecated.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors to another
+ nested structure of tensors.
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements of this dataset to combine in a single batch.
+ num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
+ representing the number of batches to create in parallel. On one hand,
+ higher values can help mitigate the effect of stragglers. On the other
+ hand, higher values can increase contention if CPU is scarce.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether the last batch should be dropped in case its size is smaller than
+ desired; the default behavior is not to drop the smaller batch.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of elements to process in parallel. If not
+ specified, `batch_size * num_parallel_batches` elements will be
+ processed in parallel.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
+ specified.
+ """
+
+ if num_parallel_batches is None and num_parallel_calls is None:
+ num_parallel_calls = batch_size
+ elif num_parallel_batches is not None and num_parallel_calls is None:
+ num_parallel_calls = batch_size * num_parallel_batches
+ elif num_parallel_batches is not None and num_parallel_calls is not None:
+ raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
+ "arguments are mutually exclusive.")
+
+ def _apply_fn(dataset):
+ return _MapAndBatchDataset(dataset, map_func, batch_size,
+ num_parallel_calls, drop_remainder)
+
+ return _apply_fn
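A minimal usage sketch of the fused transformation; the map function, batch size, and parallelism values below are arbitrary examples.

```python
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.apply(
    batching.map_and_batch(
        map_func=lambda x: x * 2,  # any element-wise preprocessing
        batch_size=32,
        num_parallel_calls=8))
# Roughly equivalent, unfused formulation:
#   dataset = dataset.map(lambda x: x * 2, num_parallel_calls=8).batch(32)
```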
diff --git a/tensorflow/python/data/experimental/ops/counter.py b/tensorflow/python/data/experimental/ops/counter.py
new file mode 100644
index 0000000000..42200eaef9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/counter.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Counter Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import scan_ops
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.Counter")
+def Counter(start=0, step=1, dtype=dtypes.int64):
+ """Creates a `Dataset` that counts from `start` in steps of size `step`.
+
+ For example:
+
+ ```python
+  Counter() == [0, 1, 2, ...)
+  Counter(2) == [2, 3, ...)
+  Counter(2, 5) == [2, 7, 12, ...)
+  Counter(0, -1) == [0, -1, -2, ...)
+  Counter(10, -1) == [10, 9, ...)
+ ```
+
+ Args:
+ start: (Optional.) The starting value for the counter. Defaults to 0.
+ step: (Optional.) The step size for the counter. Defaults to 1.
+ dtype: (Optional.) The data type for counter elements. Defaults to
+ `tf.int64`.
+
+ Returns:
+ A `Dataset` of scalar `dtype` elements.
+ """
+ with ops.name_scope("counter"):
+ start = ops.convert_to_tensor(start, dtype=dtype, name="start")
+ step = ops.convert_to_tensor(step, dtype=dtype, name="step")
+ return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
+ scan_ops.scan(start, lambda state, _: (state + step, state)))
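A usage sketch pairing the unbounded counter with a finite dataset; the element values are arbitrary.

```python
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.ops import dataset_ops

letters = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])

# Zipping against the finite dataset bounds the otherwise endless counter.
indexed = dataset_ops.Dataset.zip(
    (counter.Counter(start=10, step=2), letters))
# ==> (10, "a"), (12, "b"), (14, "c")
```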
diff --git a/tensorflow/python/data/experimental/ops/enumerate_ops.py b/tensorflow/python/data/experimental/ops/enumerate_ops.py
new file mode 100644
index 0000000000..a1af98f552
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/enumerate_ops.py
@@ -0,0 +1,60 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Enumerate dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.enumerate_dataset")
+def enumerate_dataset(start=0):
+  """A transformation that enumerates the elements of a dataset.
+
+  It is similar to Python's `enumerate`.
+ For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { 1, 2, 3 }
+ b = { (7, 8), (9, 10) }
+
+  # Each element of the input dataset is paired with its index in the
+  # enumeration, which starts from `start`.
+  a.apply(tf.data.experimental.enumerate_dataset(start=5))
+      == { (5, 1), (6, 2), (7, 3) }
+  b.apply(tf.data.experimental.enumerate_dataset())
+      == { (0, (7, 8)), (1, (9, 10)) }
+ ```
+
+ Args:
+ start: A `tf.int64` scalar `tf.Tensor`, representing the start
+ value for enumeration.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
+ return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
+ dataset))
+
+ return _apply_fn
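A usage sketch of the transformation defined above; the element values are arbitrary.

```python
from tensorflow.python.data.experimental.ops import enumerate_ops
from tensorflow.python.data.ops import dataset_ops

words = dataset_ops.Dataset.from_tensor_slices(["foo", "bar", "baz"])

# Pairs each element with a running index starting at `start`.
enumerated = words.apply(enumerate_ops.enumerate_dataset(start=1))
# ==> (1, "foo"), (2, "bar"), (3, "baz")
```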
diff --git a/tensorflow/python/data/experimental/ops/error_ops.py b/tensorflow/python/data/experimental/ops/error_ops.py
new file mode 100644
index 0000000000..82e274b70c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/error_ops.py
@@ -0,0 +1,78 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ignore_errors dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.ignore_errors")
+def ignore_errors():
+ """Creates a `Dataset` from another `Dataset` and silently ignores any errors.
+
+ Use this transformation to produce a dataset that contains the same elements
+ as the input, but silently drops any elements that caused an error. For
+ example:
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
+
+ # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
+ dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))
+
+ # Using `ignore_errors()` will drop the element that causes an error.
+ dataset =
+ dataset.apply(tf.data.experimental.ignore_errors()) # ==> {1., 0.5, 0.2}
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _IgnoreErrorsDataset(dataset)
+
+ return _apply_fn
+
+
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that silently ignores errors when computing its input."""
+
+ def __init__(self, input_dataset):
+ """See `Dataset.ignore_errors()` for details."""
+ super(_IgnoreErrorsDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/get_single_element.py b/tensorflow/python/data/experimental/ops/get_single_element.py
new file mode 100644
index 0000000000..132526166c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/get_single_element.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for Datasets and Iterators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.get_single_element")
+def get_single_element(dataset):
+ """Returns the single element in `dataset` as a nested structure of tensors.
+
+ This function enables you to use a `tf.data.Dataset` in a stateless
+ "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`.
+ This can be useful when your preprocessing transformations are expressed
+ as a `Dataset`, and you want to use the transformation at serving time.
+ For example:
+
+ ```python
+ input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])
+
+ def preprocessing_fn(input_str):
+ # ...
+ return image, label
+
+ dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
+ .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
+ .batch(BATCH_SIZE))
+
+ image_batch, label_batch = tf.data.experimental.get_single_element(dataset)
+ ```
+
+ Args:
+ dataset: A `tf.data.Dataset` object containing a single element.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects, corresponding to the single
+ element of `dataset`.
+
+ Raises:
+ TypeError: if `dataset` is not a `tf.data.Dataset` object.
+ InvalidArgumentError (at runtime): if `dataset` does not contain exactly
+ one element.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+
+ nested_ret = nest.pack_sequence_as(
+ dataset.output_types, gen_dataset_ops.dataset_to_single_element(
+ dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(dataset)))
+ return sparse.deserialize_sparse_tensors(
+ nested_ret, dataset.output_types, dataset.output_shapes,
+ dataset.output_classes)
diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py
new file mode 100644
index 0000000000..18ba583220
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/grouping.py
@@ -0,0 +1,551 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Grouping dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.group_by_reducer")
+def group_by_reducer(key_func, reducer):
+ """A transformation that groups elements and performs a reduction.
+
+  This transformation maps each element of a dataset to a key using `key_func`
+  and groups the elements by key. The `reducer` is used to process each group;
+  its `init_func` is used to initialize state for each group when it is
+  created, its `reduce_func` is used to update the state every time an element
+  is mapped to the matching group, and its `finalize_func` is used to map the
+  final state to an output value.
+
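+  For example, a minimal sketch (with illustrative values) that counts the
+  number of elements mapped to each key:
+
+  ```python
+  # Count the elements in each parity class of `tf.data.Dataset.range(5)`.
+  reducer = tf.data.experimental.Reducer(
+      init_func=lambda key: 0,                 # key => initial state
+      reduce_func=lambda count, _: count + 1,  # (state, element) => new state
+      finalize_func=lambda count: count)       # final state => output value
+  dataset = tf.data.Dataset.range(5).apply(
+      tf.data.experimental.group_by_reducer(
+          key_func=lambda x: x % 2, reducer=reducer))
+  # ==> one element per key: 3 (for 0, 2, 4) and 2 (for 1, 3)
+  ```
+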
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reducer: An instance of `Reducer`, which captures the reduction logic using
+ the `init_func`, `reduce_func`, and `finalize_func` functions.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _GroupByReducerDataset(dataset, key_func, reducer)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.group_by_window")
+def group_by_window(key_func,
+ reduce_func,
+ window_size=None,
+ window_size_func=None):
+ """A transformation that groups windows of elements by key and reduces them.
+
+ This transformation maps each consecutive element in a dataset to a key
+ using `key_func` and groups the elements by key. It then applies
+ `reduce_func` to at most `window_size_func(key)` elements matching the same
+ key. All except the final window for each key will contain
+ `window_size_func(key)` elements; the final window may be smaller.
+
+ You may provide either a constant `window_size` or a window size determined by
+ the key through `window_size_func`.
+
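+  For example, a minimal sketch (with illustrative values) that batches
+  elements with the same parity together:
+
+  ```python
+  dataset = tf.data.Dataset.range(10).apply(
+      tf.data.experimental.group_by_window(
+          key_func=lambda x: x % 2,
+          reduce_func=lambda key, window: window.batch(4),
+          window_size=4))
+  # ==> batches such as [0, 2, 4, 6], [1, 3, 5, 7], [8], [9]
+  ```
+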
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reduce_func: A function mapping a key and a dataset of up to `window_size`
+ consecutive elements matching that key to another dataset.
+ window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements matching the same key to combine in a single
+ batch, which will be passed to `reduce_func`. Mutually exclusive with
+ `window_size_func`.
+ window_size_func: A function mapping a key to a `tf.int64` scalar
+ `tf.Tensor`, representing the number of consecutive elements matching
+ the same key to combine in a single batch, which will be passed to
+ `reduce_func`. Mutually exclusive with `window_size`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if neither or both of {`window_size`, `window_size_func`} are
+ passed.
+ """
+ if (window_size is not None and window_size_func or
+ not (window_size is not None or window_size_func)):
+ raise ValueError("Must pass either window_size or window_size_func.")
+
+ if window_size is not None:
+
+ def constant_window_func(unused_key):
+ return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
+
+ window_size_func = constant_window_func
+
+ assert window_size_func is not None
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _GroupByWindowDataset(dataset, key_func, reduce_func,
+ window_size_func)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.bucket_by_sequence_length")
+def bucket_by_sequence_length(element_length_func,
+ bucket_boundaries,
+ bucket_batch_sizes,
+ padded_shapes=None,
+ padding_values=None,
+ pad_to_bucket_boundary=False,
+ no_padding=False):
+ """A transformation that buckets elements in a `Dataset` by length.
+
+ Elements of the `Dataset` are grouped together by length and then are padded
+ and batched.
+
+  This is useful for sequence tasks in which the elements have variable
+  length. Grouping together elements that have similar lengths reduces the
+  total fraction of padding in a batch, which increases training step
+  efficiency.
+
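+  For example, a rough sketch (the element values and bucket parameters are
+  illustrative only):
+
+  ```python
+  elements = [[0], [1, 2, 3], [4, 5], [6, 7, 8, 9, 10], [11, 12]]
+  dataset = tf.data.Dataset.from_generator(
+      lambda: elements, tf.int64, output_shapes=[None])
+  dataset = dataset.apply(
+      tf.data.experimental.bucket_by_sequence_length(
+          element_length_func=lambda elem: tf.shape(elem)[0],
+          bucket_boundaries=[2, 4],
+          bucket_batch_sizes=[2, 2, 2]))
+  # A batch from the middle bucket (lengths in [2, 4)), padded to the longest
+  # element in the batch: [[1, 2, 3], [4, 5, 0]]
+  ```
+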
+ Args:
+ element_length_func: function from element in `Dataset` to `tf.int32`,
+ determines the length of the element, which will determine the bucket it
+ goes into.
+ bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
+ bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
+ `len(bucket_boundaries) + 1`.
+ padded_shapes: Nested structure of `tf.TensorShape` to pass to
+ `tf.data.Dataset.padded_batch`. If not provided, will use
+ `dataset.output_shapes`, which will result in variable length dimensions
+ being padded out to the maximum length in each batch.
+ padding_values: Values to pad with, passed to
+ `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
+ pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
+ size to maximum length in batch. If `True`, will pad dimensions with
+ unknown size to bucket boundary minus 1 (i.e., the maximum length in each
+ bucket), and caller must ensure that the source `Dataset` does not contain
+ any elements with length longer than `max(bucket_boundaries)`.
+ no_padding: `bool`, indicates whether to pad the batch features (features
+ need to be either of type `tf.SparseTensor` or of same shape).
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
+ """
+ with ops.name_scope("bucket_by_seq_length"):
+ if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
+ raise ValueError(
+ "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
+
+ batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
+
+ def element_to_bucket_id(*args):
+ """Return int64 id of the length bucket for this element."""
+ seq_length = element_length_func(*args)
+
+ boundaries = list(bucket_boundaries)
+ buckets_min = [np.iinfo(np.int32).min] + boundaries
+ buckets_max = boundaries + [np.iinfo(np.int32).max]
+ conditions_c = math_ops.logical_and(
+ math_ops.less_equal(buckets_min, seq_length),
+ math_ops.less(seq_length, buckets_max))
+ bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
+
+ return bucket_id
+
+ def window_size_fn(bucket_id):
+ # The window size is set to the batch size for this bucket
+ window_size = batch_sizes[bucket_id]
+ return window_size
+
+ def make_padded_shapes(shapes, none_filler=None):
+ padded = []
+ for shape in nest.flatten(shapes):
+ shape = tensor_shape.TensorShape(shape)
+ shape = [
+ none_filler if d.value is None else d
+ for d in shape
+ ]
+ padded.append(shape)
+ return nest.pack_sequence_as(shapes, padded)
+
+ def batching_fn(bucket_id, grouped_dataset):
+ """Batch elements in dataset."""
+ batch_size = window_size_fn(bucket_id)
+ if no_padding:
+ return grouped_dataset.batch(batch_size)
+ none_filler = None
+ if pad_to_bucket_boundary:
+ err_msg = ("When pad_to_bucket_boundary=True, elements must have "
+ "length < max(bucket_boundaries).")
+ check = check_ops.assert_less(
+ bucket_id,
+ constant_op.constant(len(bucket_batch_sizes) - 1,
+ dtype=dtypes.int64),
+ message=err_msg)
+ with ops.control_dependencies([check]):
+ boundaries = constant_op.constant(bucket_boundaries,
+ dtype=dtypes.int64)
+ bucket_boundary = boundaries[bucket_id]
+ none_filler = bucket_boundary - 1
+ shapes = make_padded_shapes(
+ padded_shapes or grouped_dataset.output_shapes,
+ none_filler=none_filler)
+ return grouped_dataset.padded_batch(batch_size, shapes, padding_values)
+
+ def _apply_fn(dataset):
+ return dataset.apply(
+ group_by_window(element_to_bucket_id, batching_fn,
+ window_size_func=window_size_fn))
+
+ return _apply_fn
+
+
+def _map_x_dataset(map_func):
+ """A transformation that maps `map_func` across its input.
+
+ This transformation is similar to `tf.data.Dataset.map`, but in addition to
+ supporting dense and sparse tensor inputs, it also supports dataset inputs.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors and/or datasets
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to another nested structure of tensors and/or
+ datasets.
+
+ Returns:
+ Dataset: A `Dataset`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _MapXDataset(dataset, map_func)
+
+ return _apply_fn
+
+
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that groups its input and performs a reduction."""
+
+ def __init__(self, input_dataset, key_func, reducer):
+ """See `group_by_reducer()` for details."""
+ super(_GroupByReducerDataset, self).__init__(input_dataset)
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_init_func(reducer.init_func)
+ self._make_reduce_func(reducer.reduce_func, input_dataset)
+ self._make_finalize_func(reducer.finalize_func)
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func, "tf.data.experimental.group_by_reducer()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 tensor. "
+ "Got type=%s and shape=%s"
+ % (wrapped_func.output_types, wrapped_func.output_shapes))
+ self._key_func = wrapped_func.function
+
+ def _make_init_func(self, init_func):
+ """Make wrapping Defun for init_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ init_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=ops.Tensor,
+ input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ self._init_func = wrapped_func.function
+ self._state_classes = wrapped_func.output_classes
+ self._state_shapes = wrapped_func.output_shapes
+ self._state_types = wrapped_func.output_types
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(wrapped_func.output_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ for new_state_type, state_type in zip(
+ nest.flatten(wrapped_func.output_types),
+ nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._reduce_func = wrapped_func.function
+ self._reduce_func.add_to_graph(ops.get_default_graph())
+
+ def _make_finalize_func(self, finalize_func):
+ """Make wrapping Defun for finalize_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ finalize_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=self._state_classes,
+ input_shapes=self._state_shapes,
+ input_types=self._state_types)
+ self._finalize_func = wrapped_func.function
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_reducer_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._init_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._finalize_func.captured_inputs,
+ key_func=self._key_func,
+ init_func=self._init_func,
+ reduce_func=self._reduce_func,
+ finalize_func=self._finalize_func,
+ **dataset_ops.flat_structure(self))
+
+
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that groups its input and performs a windowed reduction."""
+
+ def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
+ """See `group_by_window()` for details."""
+ super(_GroupByWindowDataset, self).__init__(input_dataset)
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_reduce_func(reduce_func, input_dataset)
+ self._make_window_size_func(window_size_func)
+
+ def _make_window_size_func(self, window_size_func):
+ """Make wrapping Defun for window_size_func."""
+ def window_size_func_wrapper(key):
+ return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ window_size_func_wrapper,
+ "tf.data.experimental.group_by_window()",
+ input_classes=ops.Tensor,
+ input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`window_size_func` must return a single tf.int64 scalar tensor.")
+ self._window_size_func = wrapped_func.function
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+ def key_func_wrapper(*args):
+ return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func_wrapper, "tf.data.experimental.group_by_window()",
+ input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 scalar tensor.")
+ self._key_func = wrapped_func.function
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+ nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func,
+ "tf.data.experimental.reduce_by_window()",
+ input_classes=(ops.Tensor, nested_dataset),
+ input_shapes=(tensor_shape.scalar(), nested_dataset),
+ input_types=(dtypes.int64, nested_dataset),
+ experimental_nested_dataset_support=True)
+ if not isinstance(
+ wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
+ raise TypeError("`reduce_func` must return a `Dataset` object.")
+ self._output_classes = wrapped_func.output_classes.output_classes
+ self._output_types = wrapped_func.output_types.output_types
+ self._output_shapes = wrapped_func.output_shapes.output_shapes
+ self._reduce_func = wrapped_func.function
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._window_size_func.captured_inputs,
+ key_func=self._key_func,
+ reduce_func=self._reduce_func,
+ window_size_func=self._window_size_func,
+ **dataset_ops.flat_structure(self))
+
+
+@tf_export("data.experimental.Reducer")
+class Reducer(object):
+ """A reducer is used for reducing a set of elements.
+
+  A reducer is represented as a tuple of three functions:
+ 1) initialization function: key => initial state
+ 2) reduce function: (old state, input) => new state
+ 3) finalization function: state => result
+ """
+
+ def __init__(self, init_func, reduce_func, finalize_func):
+ self._init_func = init_func
+ self._reduce_func = reduce_func
+ self._finalize_func = finalize_func
+
+ @property
+ def init_func(self):
+ return self._init_func
+
+ @property
+ def reduce_func(self):
+ return self._reduce_func
+
+ @property
+ def finalize_func(self):
+ return self._finalize_func
+
+
+class _MapXDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that maps a function over elements in its input."""
+
+ def __init__(self, input_dataset, map_func):
+ """See `map_x_dataset()` for details."""
+ super(_MapXDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ map_func,
+ "tf.data.experimental.map_x_dataset()",
+ input_dataset,
+ experimental_nested_dataset_support=True)
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+ self._map_func = wrapped_func.function
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.map_dataset(
+ input_t,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py
new file mode 100644
index 0000000000..9c06474a2f
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py
@@ -0,0 +1,177 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for indexed datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+
+
+class MaterializedIndexedDataset(object):
+ """MaterializedIndexedDataset is highly experimental!
+ """
+
+ def __init__(self, materialized_resource, materializer, output_classes,
+ output_types, output_shapes):
+ self._materialized_resource = materialized_resource
+ self._materializer = materializer
+ self._output_classes = output_classes
+ self._output_types = output_types
+ self._output_shapes = output_shapes
+
+ @property
+ def initializer(self):
+ if self._materializer is not None:
+ return self._materializer
+ raise ValueError("MaterializedDataset does not have a materializer")
+
+ def get(self, index):
+ """Get retrieves a value (or set of values) from the IndexedDataset.
+
+ Args:
+ index: A uint64 scalar or vector tensor with the indices to retrieve.
+
+ Returns:
+ A tensor containing the values corresponding to `index`.
+ """
+ # TODO(saeta): nest.pack_sequence_as(...)
+ return ged_ops.experimental_indexed_dataset_get(
+ self._materialized_resource,
+ index,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes)),
+ output_shapes=nest.flatten(
+            sparse.as_dense_shapes(self._output_shapes, self._output_classes)))
+
+
+class IndexedDataset(dataset_ops.Dataset):
+ """IndexedDataset is highly experimental!
+ """
+
+ def __init__(self):
+ pass
+
+ def materialize(self, shared_name=None, container=None):
+ """Materialize creates a MaterializedIndexedDataset.
+
+ IndexedDatasets can be combined through operations such as TBD. Therefore,
+ they are only materialized when absolutely required.
+
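+    For example, a rough sketch using the test-only `IdentityIndexedDataset`
+    (whose value at index `i` is `i` itself):
+
+    ```python
+    ds = IdentityIndexedDataset(100)
+    materialized = ds.materialize()
+    with tf.Session() as sess:
+      sess.run(materialized.initializer)
+      index = tf.constant(3, dtype=tf.uint64)
+      print(sess.run(materialized.get(index)))  # ==> the value at index 3
+    ```
+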
+ Args:
+ shared_name: a string for the shared name to use for the resource.
+ container: a string for the container to store the resource.
+
+ Returns:
+ A MaterializedIndexedDataset.
+ """
+ if container is None:
+ container = ""
+ if shared_name is None:
+ shared_name = ""
+ materialized_resource = (
+ ged_ops.experimental_materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+                sparse.as_dense_shapes(self.output_shapes,
+                                       self.output_classes))))
+
+ with ops.colocate_with(materialized_resource):
+ materializer = ged_ops.experimental_indexed_dataset_materialize(
+ self._as_variant_tensor(), materialized_resource)
+ return MaterializedIndexedDataset(materialized_resource, materializer,
+ self.output_classes, self.output_types,
+ self.output_shapes)
+
+ @abc.abstractproperty
+ def output_types(self):
+ """Returns the type of each component of an element of this IndexedDataset.
+
+ Returns:
+ A nested structure of `tf.DType` objects corresponding to each component
+ of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_types")
+
+ @abc.abstractproperty
+ def output_classes(self):
+ """Returns the class of each component of an element of this IndexedDataset.
+
+ The expected values are `tf.Tensor` and `tf.SparseTensor`.
+
+ Returns:
+ A nested structure of Python `type` objects corresponding to each
+ component of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_classes")
+
+ @abc.abstractproperty
+ def output_shapes(self):
+ """Returns the shape of each component of an element of this IndexedDataset.
+
+ Returns:
+ A nested structure of `tf.TensorShape` objects corresponding to each
+ component of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_shapes")
+
+ @abc.abstractmethod
+ def _as_variant_tensor(self):
+ """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset.
+
+ Returns:
+ A scalar `tf.Tensor` of `tf.variant` type, which represents this
+ IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset._as_variant_tensor")
+
+
+class IdentityIndexedDataset(IndexedDataset):
+ """IdentityIndexedDataset is a trivial indexed dataset used for testing.
+ """
+
+ def __init__(self, size):
+ super(IdentityIndexedDataset, self).__init__()
+ # TODO(saeta): Verify _size is a scalar!
+ self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size")
+
+ @property
+ def output_types(self):
+ return dtypes.uint64
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_identity_indexed_dataset(self._size)
+
+ def _inputs(self):
+ return []
diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py
new file mode 100644
index 0000000000..a3c094859e
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/interleave_ops.py
@@ -0,0 +1,262 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Non-deterministic dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.ops import gen_stateless_random_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.parallel_interleave")
+def parallel_interleave(map_func,
+ cycle_length,
+ block_length=1,
+ sloppy=False,
+ buffer_output_elements=None,
+ prefetch_input_elements=None):
+ """A parallel version of the `Dataset.interleave()` transformation.
+
+ `parallel_interleave()` maps `map_func` across its input to produce nested
+ datasets, and outputs their elements interleaved. Unlike
+ `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested
+ datasets in parallel, which increases the throughput, especially in the
+ presence of stragglers. Furthermore, the `sloppy` argument can be used to
+ improve performance, by relaxing the requirement that the outputs are produced
+ in a deterministic order, and allowing the implementation to skip over nested
+ datasets whose elements are not readily available when requested.
+
+ Example usage:
+
+ ```python
+ # Preprocess 4 files concurrently.
+ filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
+ dataset = filenames.apply(
+ tf.data.experimental.parallel_interleave(
+ lambda filename: tf.data.TFRecordDataset(filename),
+ cycle_length=4))
+ ```
+
+ WARNING: If `sloppy` is `True`, the order of produced elements is not
+ deterministic.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors to a `Dataset`.
+ cycle_length: The number of input `Dataset`s to interleave from in parallel.
+ block_length: The number of consecutive elements to pull from an input
+ `Dataset` before advancing to the next input `Dataset`.
+ sloppy: If false, elements are produced in deterministic order. Otherwise,
+ the implementation is allowed, for the sake of expediency, to produce
+ elements in a non-deterministic order.
+ buffer_output_elements: The number of elements each iterator being
+ interleaved should buffer (similar to the `.prefetch()` transformation for
+ each interleaved iterator).
+ prefetch_input_elements: The number of input elements to transform to
+ iterators before they are needed for interleaving.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return readers.ParallelInterleaveDataset(
+ dataset, map_func, cycle_length, block_length, sloppy,
+ buffer_output_elements, prefetch_input_elements)
+
+ return _apply_fn
+
+
+class _DirectedInterleaveDataset(dataset_ops.Dataset):
+ """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
+
+ def __init__(self, selector_input, data_inputs):
+ self._selector_input = selector_input
+ self._data_inputs = list(data_inputs)
+
+ for data_input in data_inputs[1:]:
+ if (data_input.output_types != data_inputs[0].output_types or
+ data_input.output_classes != data_inputs[0].output_classes):
+ raise TypeError("All datasets must have the same type and class.")
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return (
+ gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(), [
+ data_input._as_variant_tensor()
+ for data_input in self._data_inputs
+ ], **dataset_ops.flat_structure(self)))
+ # pylint: enable=protected-access
+
+ def _inputs(self):
+ return [self._selector_input] + self._data_inputs
+
+ @property
+ def output_classes(self):
+ return self._data_inputs[0].output_classes
+
+ @property
+ def output_shapes(self):
+ ret = self._data_inputs[0].output_shapes
+ for data_input in self._data_inputs[1:]:
+ ret = nest.pack_sequence_as(ret, [
+ ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
+ nest.flatten(ret), nest.flatten(data_input.output_shapes))
+ ])
+ return ret
+
+ @property
+ def output_types(self):
+ return self._data_inputs[0].output_types
+
+
+@tf_export("data.experimental.sample_from_datasets")
+def sample_from_datasets(datasets, weights=None, seed=None):
+ """Samples elements at random from the datasets in `datasets`.
+
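+  For example, a minimal sketch (with illustrative weights):
+
+  ```python
+  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
+              tf.data.Dataset.from_tensors("bar").repeat()]
+
+  # Produces "foo" roughly three times as often as "bar".
+  dataset = tf.data.experimental.sample_from_datasets(
+      datasets, weights=[0.75, 0.25])
+  ```
+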
+ Args:
+ datasets: A list of `tf.data.Dataset` objects with compatible structure.
+ weights: (Optional.) A list of `len(datasets)` floating-point values where
+ `weights[i]` represents the probability with which an element should be
+ sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
+ element is such a list. Defaults to a uniform distribution across
+ `datasets`.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ `tf.set_random_seed` for behavior.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` at random, according to
+ `weights` if provided, otherwise with uniform probability.
+
+ Raises:
+ TypeError: If the `datasets` or `weights` arguments have the wrong type.
+ ValueError: If the `weights` argument is specified and does not match the
+ length of the `datasets` element.
+ """
+ num_datasets = len(datasets)
+ if not isinstance(weights, dataset_ops.Dataset):
+ if weights is None:
+ # Select inputs with uniform probability.
+ logits = [[1.0] * num_datasets]
+
+ else:
+ # Use the given `weights` as the probability of choosing the respective
+ # input.
+ weights = ops.convert_to_tensor(weights, name="weights")
+ if weights.dtype not in (dtypes.float32, dtypes.float64):
+ raise TypeError("`weights` must be convertible to a tensor of "
+ "`tf.float32` or `tf.float64` elements.")
+ if not weights.shape.is_compatible_with([num_datasets]):
+ raise ValueError(
+ "`weights` must be a vector of length `len(datasets)`.")
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed
+ # to weights.
+ logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+
+ # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
+ # is a `Dataset`, it is possible that evaluating it has a side effect the
+ # user depends on.
+ if len(datasets) == 1:
+ return datasets[0]
+
+ def select_dataset_constant_logits(seed):
+ return array_ops.squeeze(
+ gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
+ axis=[0, 1])
+
+ selector_input = dataset_ops.MapDataset(
+ random_ops.RandomDataset(seed).batch(2),
+ select_dataset_constant_logits,
+ use_inter_op_parallelism=False)
+
+ else:
+ # Use each element of the given `weights` dataset as the probability of
+ # choosing the respective input.
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed to
+ # weights.
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+
+ def select_dataset_varying_logits(logits, seed):
+ return array_ops.squeeze(
+ gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
+ axis=[0, 1])
+
+ logits_and_seeds = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2)))
+ selector_input = dataset_ops.MapDataset(
+ logits_and_seeds,
+ select_dataset_varying_logits,
+ use_inter_op_parallelism=False)
+
+ return _DirectedInterleaveDataset(selector_input, datasets)
+
+
+@tf_export("data.experimental.choose_from_datasets")
+def choose_from_datasets(datasets, choice_dataset):
+ """Creates a dataset that deterministically chooses elements from `datasets`.
+
+ For example, given the following datasets:
+
+ ```python
+ datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
+ tf.data.Dataset.from_tensors("bar").repeat(),
+ tf.data.Dataset.from_tensors("baz").repeat()]
+
+ # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
+ choice_dataset = tf.data.Dataset.range(3).repeat(3)
+
+ result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
+ ```
+
+ The elements of `result` will be:
+
+ ```
+ "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
+ ```
+
+ Args:
+ datasets: A list of `tf.data.Dataset` objects with compatible structure.
+ choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
+ `0` and `len(datasets) - 1`.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` according to the values
+ of `choice_dataset`.
+
+ Raises:
+ TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
+ type.
+ """
+ if not (choice_dataset.output_types == dtypes.int64
+ and choice_dataset.output_shapes.is_compatible_with(
+ tensor_shape.scalar())
+ and choice_dataset.output_classes == ops.Tensor):
+ raise TypeError("`choice_dataset` must be a dataset of scalar "
+ "`tf.int64` tensors.")
+ return _DirectedInterleaveDataset(choice_dataset, datasets)
diff --git a/tensorflow/python/data/experimental/ops/iterator_ops.py b/tensorflow/python/data/experimental/ops/iterator_ops.py
new file mode 100644
index 0000000000..72d7d58f06
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/iterator_ops.py
@@ -0,0 +1,268 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Iterator ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.make_saveable_from_iterator")
+def make_saveable_from_iterator(iterator):
+ """Returns a SaveableObject for saving/restore iterator state using Saver.
+
+ Args:
+ iterator: Iterator.
+
+ For example:
+
+ ```python
+ with tf.Graph().as_default():
+ ds = tf.data.Dataset.range(10)
+ iterator = ds.make_initializable_iterator()
+ # Build the iterator SaveableObject.
+ saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
+ # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
+ # it can be automatically saved using Saver.
+ tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
+ saver = tf.train.Saver()
+
+ while continue_training:
+ ... Perform training ...
+ if should_save_checkpoint:
+ saver.save()
+ ```
+
+ Note: When restoring the iterator, the existing iterator state is completely
+ discarded. This means that any changes you may have made to the Dataset
+ graph will be discarded as well! This includes the new Dataset graph
+ that you may have built during validation. So, while running validation,
+ make sure to run the initializer for the validation input pipeline after
+ restoring the checkpoint.
+
+ Note: Not all iterators support checkpointing yet. Attempting to save the
+ state of an unsupported iterator will throw an error.
+ """
+ return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
+
+
+class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
+ """SaveableObject for saving/restoring iterator state."""
+
+ def __init__(self, iterator_resource):
+ serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
+ specs = [
+ saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
+ iterator_resource.name + "-state")
+ ]
+ super(_Saveable, self).__init__(iterator_resource, specs,
+ iterator_resource.name)
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ with ops.colocate_with(self.op):
+ return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+@tf_export("data.experimental.CheckpointInputPipelineHook")
+class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+ """Checkpoints input pipeline state every N steps or seconds.
+
+ This hook saves the state of the iterators in the `Graph` so that when
+ training is resumed the input pipeline continues from where it left off.
+  This could potentially avoid overfitting in certain pipelines where the
+  number of training steps per eval is small compared to the dataset
+  size, or if the training pipeline is preempted.
+
+ Differences from `CheckpointSaverHook`:
+ 1. Saves only the input pipelines in the "iterators" collection and not the
+ global variables or other saveable objects.
+ 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
+
+ Example of checkpointing the training pipeline:
+
+ ```python
+ est = tf.estimator.Estimator(model_fn)
+ while True:
+ est.train(
+ train_input_fn,
+ hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
+ steps=train_steps_per_eval)
+ # Note: We do not pass the hook here.
+ metrics = est.evaluate(eval_input_fn)
+ if should_stop_the_training(metrics):
+ break
+ ```
+
+ This hook should be used if the input pipeline state needs to be saved
+ separate from the model checkpoint. Doing so may be useful for a few reasons:
+ 1. The input pipeline checkpoint may be large, if there are large shuffle
+ or prefetch buffers for instance, and may bloat the checkpoint size.
+ 2. If the input pipeline is shared between training and validation, restoring
+ the checkpoint during validation may override the validation input
+ pipeline.
+
+ For saving the input pipeline checkpoint alongside the model weights use
+ `tf.data.experimental.make_saveable_from_iterator` directly to create a
+ `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
+ that you will need to be careful not to restore the training iterator during
+ eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
+  collection when building the eval graph.
+ """
+
+ def __init__(self, estimator):
+ """Initializes a `CheckpointInputPipelineHook`.
+
+ Args:
+ estimator: Estimator.
+
+ Raises:
+ ValueError: One of `save_steps` or `save_secs` should be set.
+ ValueError: At most one of saver or scaffold should be set.
+ """
+ # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
+ # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
+ # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
+ # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
+ # to be different to avoid conflicts with the model checkpoint.
+
+ # pylint: disable=protected-access
+ checkpoint_prefix = "input"
+ if estimator._config.num_worker_replicas > 1:
+ # Distributed setting.
+ suffix = "_{}_{}".format(estimator._config.task_type,
+ estimator._config.task_id)
+ checkpoint_prefix += suffix
+ # pylint: enable=protected-access
+
+ # We use a composition paradigm instead of inheriting from
+ # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
+ # to check whether a `CheckpointSaverHook` is already present in the list
+ # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
+ # would thwart this behavior. This hook checkpoints *only the iterators*
+ # and not the graph variables.
+ self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
+ estimator.model_dir,
+ save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
+ save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
+ checkpoint_basename=checkpoint_prefix + ".ckpt")
+
+ # Name for the protocol buffer file that will contain the list of most
+ # recent checkpoints stored as a `CheckpointState` protocol buffer.
+ # This file, kept in the same directory as the checkpoint files, is
+ # automatically managed by the `Saver` to keep track of recent checkpoints.
+ # The default name used by the `Saver` for this file is "checkpoint". Here
+ # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
+ # `checkpoint_dir` is the same as the model checkpoint directory, there are
+ # no conflicts during restore.
+ self._latest_filename = "checkpoint_" + checkpoint_prefix
+ self._first_run = True
+
+ def begin(self):
+ # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
+ # collection if no `Saver` or `Scaffold` is provided.
+ # pylint: disable=protected-access
+ if (self._checkpoint_saver_hook._saver is None and
+ self._checkpoint_saver_hook._scaffold is None):
+ iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
+ saveables = [_Saveable(i) for i in iterators]
+ self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
+ self._latest_filename)
+ # pylint: enable=protected-access
+ self._checkpoint_saver_hook.begin()
+
+ def _restore_or_save_initial_ckpt(self, session):
+ # Ideally this should be run in after_create_session but is not for the
+ # following reason:
+ # Currently there is no way of enforcing an order of running the
+ # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
+ # is run *after* this hook. That is troublesome because
+ # 1. If a checkpoint exists and this hook restores it, the initializer hook
+ # will override it.
+ # 2. If no checkpoint exists, this hook will try to save an initialized
+ # iterator which will result in an exception.
+ #
+ # As a temporary fix we enter the following implicit contract between this
+ # hook and the _DatasetInitializerHook.
+ # 1. The _DatasetInitializerHook initializes the iterator in the call to
+ # after_create_session.
+ # 2. This hook saves the iterator on the first call to `before_run()`, which
+ # is guaranteed to happen after `after_create_session()` of all hooks
+ # have been run.
+
+ # Check if there is an existing checkpoint. If so, restore from it.
+ # pylint: disable=protected-access
+ latest_checkpoint_path = checkpoint_management.latest_checkpoint(
+ self._checkpoint_saver_hook._checkpoint_dir,
+ latest_filename=self._latest_filename)
+ if latest_checkpoint_path:
+ self._checkpoint_saver_hook._get_saver().restore(session,
+ latest_checkpoint_path)
+ else:
+ # The checkpoint saved here is the state at step "global_step".
+ # Note: We do not save the GraphDef or MetaGraphDef here.
+ global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
+ self._checkpoint_saver_hook._save(session, global_step)
+ self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
+ # pylint: enable=protected-access
+
+ def before_run(self, run_context):
+ if self._first_run:
+ self._restore_or_save_initial_ckpt(run_context.session)
+ self._first_run = False
+ return self._checkpoint_saver_hook.before_run(run_context)
+
+ def after_run(self, run_context, run_values):
+ self._checkpoint_saver_hook.after_run(run_context, run_values)
+
+ def end(self, session):
+ self._checkpoint_saver_hook.end(session)
+
+
+class _CustomSaver(saver_lib.Saver):
+ """`Saver` with a different default `latest_filename`.
+
+ This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
+ the model ckpt saved by the `CheckpointSaverHook`.
+ """
+
+ def __init__(self, var_list, latest_filename):
+ super(_CustomSaver, self).__init__(var_list)
+ self._latest_filename = latest_filename
+
+ def save(self,
+ sess,
+ save_path,
+ global_step=None,
+ latest_filename=None,
+ meta_graph_suffix="meta",
+ write_meta_graph=True,
+ write_state=True,
+ strip_default_attrs=False):
+ return super(_CustomSaver, self).save(
+ sess, save_path, global_step, latest_filename or self._latest_filename,
+ meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
+
+
+tf_export("data.experimental.Optional")(optional_ops.Optional)
+tf_export("data.experimental.get_next_as_optional")(
+ iterator_ops.get_next_as_optional)
diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py
new file mode 100644
index 0000000000..3d0d0993c9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/map_defun.py
@@ -0,0 +1,56 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for optimizing `tf.data` pipelines."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+
+
+def map_defun(fn, elems, output_dtypes, output_shapes):
+ """Map a function on the list of tensors unpacked from `elems` on dimension 0.
+
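+  For example, a minimal sketch (the `double` function below is illustrative),
+  using the internal `function.Defun` decorator:
+
+  ```python
+  from tensorflow.python.framework import function
+
+  @function.Defun(tf.int32)
+  def double(x):
+    return x * 2
+
+  # Applies `double` to each of the three rows independently.
+  result = map_defun(
+      double, [tf.constant([[1, 2], [3, 4], [5, 6]])],
+      output_dtypes=[tf.int32], output_shapes=[(2,)])
+  # result[0] ==> [[2, 4], [6, 8], [10, 12]]
+  ```
+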
+ Args:
+ fn: A function (`function.Defun`) that takes a list of tensors and returns
+ another list of tensors. The output list has the same types as
+      `output_dtypes`. The elements of the output list have the same
+      dimension 0 as `elems`, and the remaining dimensions correspond to
+      those of `output_shapes`.
+ elems: A list of tensors.
+ output_dtypes: A list of dtypes corresponding to the output types of the
+ function.
+ output_shapes: A list of `TensorShape`s corresponding to the output
+ shapes from each invocation of the function on slices of inputs.
+
+ Raises:
+ ValueError: if any of the inputs are malformed.
+
+ Returns:
+ A list of `Tensor` objects with the same types as `output_dtypes`.
+ """
+ if not isinstance(elems, list):
+ raise ValueError("`elems` must be a list of tensors.")
+ if not isinstance(output_dtypes, list):
+ raise ValueError("`output_dtypes` must be a list of tensors.")
+ if not isinstance(output_shapes, list):
+ raise ValueError("`output_shapes` must be a list of tensors.")
+
+ elems = [ops.convert_to_tensor(e) for e in elems]
+ output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
+ return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py
new file mode 100644
index 0000000000..30348ede36
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/optimization.py
@@ -0,0 +1,171 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for optimizing `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
+
+
+# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
+# account for indexing) and transformation sequence.
+def assert_next(transformations):
+ """A transformation that asserts which transformations happen next.
+
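+  For example, a minimal sketch (illustrative only):
+
+  ```python
+  # Fails at iteration time unless a `Map` transformation immediately
+  # follows in the optimized input pipeline.
+  dataset = tf.data.Dataset.from_tensors(0).apply(
+      assert_next(["Map"])).map(lambda x: x)
+  ```
+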
+ Args:
+ transformations: A `tf.string` vector `tf.Tensor` identifying the
+ transformations that are expected to happen next.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _AssertNextDataset(dataset, transformations)
+
+ return _apply_fn
+
+
+def model():
+ """A transformation that models performance.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _ModelDataset(dataset)
+
+ return _apply_fn
+
+
+def optimize(optimizations=None):
+ """A transformation that applies optimizations.
+
+ Args:
+ optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying
+ optimizations to use. If not specified, the default set of optimizations
+ is applied.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _OptimizeDataset(dataset, optimizations)
+
+ return _apply_fn
+
+
+class _AssertNextDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that asserts which transformations happen next."""
+
+ def __init__(self, input_dataset, transformations):
+ """See `assert_next()` for details."""
+ super(_AssertNextDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if transformations is None:
+ raise ValueError("At least one transformation should be specified")
+ self._transformations = ops.convert_to_tensor(
+ transformations, dtype=dtypes.string, name="transformations")
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_assert_next_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._transformations,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _ModelDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and models performance."""
+
+ def __init__(self, input_dataset):
+ """See `optimize()` for details."""
+ super(_ModelDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.model_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _OptimizeDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and applies optimizations."""
+
+ def __init__(self, input_dataset, optimizations):
+ """See `optimize()` for details."""
+ super(_OptimizeDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if optimizations is None:
+ optimizations = []
+ self._optimizations = ops.convert_to_tensor(
+ optimizations, dtype=dtypes.string, name="optimizations")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.optimize_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._optimizations,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
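As a usage illustration for the transformations above, the sketch below requests the map-and-batch fusion rewrite and asserts that the optimized pipeline begins with the fused op. The names `"MapAndBatch"` and `"map_and_batch_fusion"` follow the kernel tests added in this change and should be treated as assumptions about the rewrite registry:

```python
import tensorflow as tf

from tensorflow.python.data.experimental.ops import optimization

dataset = tf.data.Dataset.range(10)
# `assert_next` records the expectation just before the transformations it
# checks in the optimized graph.
dataset = dataset.apply(optimization.assert_next(["MapAndBatch"]))
dataset = dataset.map(lambda x: x * x).batch(5)
# Ask Grappler to apply only the named rewrite.
dataset = dataset.apply(optimization.optimize(["map_and_batch_fusion"]))

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(next_element))  # [ 0  1  4  9 16]
```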
diff --git a/tensorflow/python/data/experimental/ops/parsing_ops.py b/tensorflow/python/data/experimental/ops/parsing_ops.py
new file mode 100644
index 0000000000..6615b9022a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/parsing_ops.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental `dataset` API for parsing example."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that parses `example` dataset into a `dict` dataset."""
+
+ def __init__(self, input_dataset, features, num_parallel_calls):
+ super(_ParseExampleDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if not all(types == dtypes.string
+ for types in nest.flatten(input_dataset.output_types)):
+ raise TypeError("Input dataset should be a dataset of vectors of strings")
+ self._num_parallel_calls = num_parallel_calls
+ # pylint: disable=protected-access
+ self._features = parsing_ops._prepend_none_dimension(features)
+ # sparse_keys and dense_keys come back sorted here.
+ (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
+ dense_shapes) = parsing_ops._features_to_raw_params(
+ self._features, [
+ parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
+ parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
+ ])
+ # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
+ (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
+ dense_shape_as_shape) = parsing_ops._process_raw_parameters(
+ None, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
+ # pylint: enable=protected-access
+ self._sparse_keys = sparse_keys
+ self._sparse_types = sparse_types
+ self._dense_keys = dense_keys
+ self._dense_defaults = dense_defaults_vec
+ self._dense_shapes = dense_shapes
+ self._dense_types = dense_types
+ dense_output_shapes = [
+ self._input_dataset.output_shapes.concatenate(shape)
+ for shape in dense_shape_as_shape
+ ]
+ sparse_output_shapes = [
+ self._input_dataset.output_shapes.concatenate([None])
+ for _ in range(len(sparse_keys))
+ ]
+
+ self._output_shapes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ dense_output_shapes + sparse_output_shapes))
+ self._output_types = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ self._dense_types + self._sparse_types))
+ self._output_classes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ [ops.Tensor for _ in range(len(self._dense_defaults))] +
+ [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
+ ]))
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parse_example_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._num_parallel_calls,
+ self._dense_defaults,
+ self._sparse_keys,
+ self._dense_keys,
+ self._sparse_types,
+ self._dense_shapes,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+# TODO(b/111553342): add arguments names and example names as well.
+@tf_export("data.experimental.parse_example_dataset")
+def parse_example_dataset(features, num_parallel_calls=1):
+ """A transformation that parses `Example` protos into a `dict` of tensors.
+
+ Parses a number of serialized `Example` protos given in `serialized`. We refer
+ to `serialized` as a batch with `batch_size` many entries of individual
+ `Example` protos.
+
+ This op parses serialized examples into a dictionary mapping keys to `Tensor`
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more
+ details about feature dictionaries.
+
+ Args:
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of parsing processes to call in parallel.
+
+ Returns:
+ A dataset transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if the `features` argument is `None`.
+ """
+ if features is None:
+ raise ValueError("Missing: features was %s." % features)
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
+ if any([
+ isinstance(feature, parsing_ops.SparseFeature)
+ for _, feature in features.items()
+ ]):
+ # pylint: disable=protected-access
+ # pylint: disable=g-long-lambda
+ out_dataset = out_dataset.map(
+ lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
+ features, x), num_parallel_calls=num_parallel_calls)
+ return out_dataset
+
+ return _apply_fn
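A hedged sketch of applying `parse_example_dataset` to a batched dataset of serialized `tf.Example` strings, assuming the `tf.data.experimental` endpoint introduced by this change is available; the feature spec and proto contents are invented:

```python
import tensorflow as tf

features = {
    "age": tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": tf.FixedLenFeature([], dtype=tf.string),
}

example = tf.train.Example(features=tf.train.Features(feature={
    "age": tf.train.Feature(int64_list=tf.train.Int64List(value=[42])),
    "gender": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"f"])),
}))

# The input must be a dataset of *vectors* of serialized protos, e.g. the
# output of `.batch()`; here a single length-1 vector is used.
dataset = tf.data.Dataset.from_tensors([example.SerializeToString()])
dataset = dataset.apply(
    tf.data.experimental.parse_example_dataset(features, num_parallel_calls=2))

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(next_element))  # {'age': array([42]), 'gender': array([b'f'])}
```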
diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py
new file mode 100644
index 0000000000..48d7136f95
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py
@@ -0,0 +1,531 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import device as framework_device
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def function_buffering_resource(string_arg,
+ target_device,
+ f,
+ buffer_size,
+ output_types,
+ container="",
+ shared_name=None,
+ name=None):
+ """Creates a FunctionBufferingResource.
+
+ A FunctionBufferingResource fills up a buffer by calling a function `f` on
+ `target_device`. `f` should take in only a single string argument as input.
+
+ Args:
+ string_arg: The single string argument to the function.
+ target_device: The device to run `f` on.
+ f: The function to be executed.
+ buffer_size: Size of the buffer to be populated.
+ output_types: The output types generated by the function.
+ container: (Optional) string. Defaults to "".
+ shared_name: (Optional) string.
+ name: (Optional) string to name the op.
+
+ Returns:
+ Handle to a FunctionBufferingResource.
+ """
+ if shared_name is None:
+ shared_name = ""
+ return ged_ops.experimental_function_buffering_resource(
+ string_arg=string_arg,
+ target_device=target_device,
+ shared_name=shared_name,
+ f=f,
+ buffer_size=buffer_size,
+ container=container,
+ name=name,
+ output_types=output_types)
+
+
+def function_buffering_resource_get_next(function_buffer_resource,
+ output_types,
+ name=None):
+ return ged_ops.experimental_function_buffering_resource_get_next(
+ function_buffer_resource=function_buffer_resource,
+ output_types=output_types,
+ name=name)
+
+
+def function_buffering_resource_reset(function_buffer_resource, name=None):
+ return ged_ops.experimental_function_buffering_resource_reset(
+ function_buffer_resource=function_buffer_resource, name=name)
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ device: A fully specified device string to which elements are prefetched.
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be
+ shared under the given name across multiple sessions that share the
+ same devices (e.g. when using a remote server).
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ device,
+ buffer_size,
+ shared_name=None):
+ self._input_dataset = input_dataset
+ self._get_next_call_count = 0
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ iterator_device = ged_ops.experimental_iterator_get_device(
+ self._input_iterator._iterator_resource)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ target_device=iterator_device,
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)))
+
+ if not self._one_shot:
+ reset_op = function_buffering_resource_reset(self._buffering_resource)
+ with ops.control_dependencies([reset_op]):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
+
+ def get_next(self, name=None):
+ """See `tf.data.Iterator.get_next`."""
+ self._get_next_call_count += 1
+ if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+ warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
+ self._buffering_resource,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
+
+ ret = sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self.output_types, flat_ret),
+ self.output_types, self.output_shapes, self.output_classes)
+
+ for tensor, shape in zip(
+ nest.flatten(ret), nest.flatten(self.output_shapes)):
+ if isinstance(tensor, ops.Tensor):
+ tensor.set_shape(shape)
+
+ return ret
+
+ @property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset.
+ device: A fully specified device string to which elements are prefetched.
+ buffer_size: Size of the prefetching buffer.
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ device,
+ buffer_size):
+ with ops.device("/device:CPU:0"):
+ super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
+ input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
+ self._resource)
+
+ self._device = device
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self.output_types, self.output_shapes, self.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ _prefetch_fn.add_to_graph(None)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ output_types=self._flat_output_types,
+ target_device=ged_ops.experimental_iterator_get_device(
+ self._resource),
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=iterator_ops._generate_shared_name(
+ "function_buffer_resource"))
+
+ def _next_internal(self):
+ """Returns a nested structure of `tf.Tensor`s containing the next element.
+ """
+ # This runs in sync mode as iterators use an error status to communicate
+ # that there is no more data to iterate over.
+ # TODO(b/77291417): Fix
+ with context.execution_mode(context.SYNC):
+ with ops.device(self._device):
+ ret = ged_ops.experimental_function_buffering_resource_get_next(
+ function_buffer_resource=self._buffering_resource,
+ output_types=self._flat_output_types)
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` whose iterator prefetches elements to another device."""
+
+ def __init__(self, input_dataset, device, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._device = device
+ self._buffer_size = buffer_size if buffer_size is not None else 1
+
+ # The static analysis cannot tell that the eager iterator's superclass has
+ # a `next()` method.
+ # pylint: disable=non-iterator-returned
+ def __iter__(self):
+ """Creates an `Iterator` for enumerating the elements of this dataset.
+
+ The returned iterator implements the Python iterator protocol and therefore
+ can only be used in eager mode.
+
+ Returns:
+ An `Iterator` over the elements of this dataset.
+
+ Raises:
+ RuntimeError: If eager execution is enabled.
+ """
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ raise RuntimeError("dataset.__iter__() is only supported when eager "
+ "execution is enabled.")
+ # pylint: enable=non-iterator-returned
+
+ def make_one_shot_iterator(self):
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
+ device=self._device,
+ buffer_size=self._buffer_size)
+
+ def make_initializable_iterator(self, shared_name=None):
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ device=self._device,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
+
+ def _as_variant_tensor(self):
+ # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+ # transformation methods is called).
+ # TODO(mrry): Investigate support for chaining further transformations after
+ # the prefetch, including GPU support.
+ raise NotImplementedError("`prefetch_to_device()` must be the last "
+ "transformation in a dataset pipeline.")
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+@tf_export("data.experimental.prefetch_to_device")
+def prefetch_to_device(device, buffer_size=None):
+ """A transformation that prefetches dataset values to the given `device`.
+
+ NOTE: Although the transformation creates a `tf.data.Dataset`, the
+ transformation must be the final transformation in the input pipeline.
+
+ Args:
+ device: A string. The name of a device to which elements will be prefetched.
+ buffer_size: (Optional.) The number of elements to buffer on `device`.
+ Defaults to an automatically chosen value.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return _PrefetchToDeviceDataset(dataset, device, buffer_size)
+
+ return _apply_fn
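A minimal sketch of `prefetch_to_device`; as noted above, it must be the last transformation in the pipeline. The device string assumes a machine with one GPU:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100).batch(10)
dataset = dataset.apply(
    tf.data.experimental.prefetch_to_device("/gpu:0", buffer_size=2))

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(next_element))  # [0 1 2 3 4 5 6 7 8 9], staged on /gpu:0
```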
+
+
+@tf_export("data.experimental.copy_to_device")
+def copy_to_device(target_device, source_device="/cpu:0"):
+ """A transformation that copies dataset elements to the given `target_device`.
+
+ Args:
+ target_device: The name of a device to which elements will be copied.
+ source_device: The original device on which `input_dataset` will be placed.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _CopyToDeviceDataset(
+ dataset, target_device=target_device, source_device=source_device)
+
+ return _apply_fn
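A hedged sketch of `copy_to_device`, which, unlike `prefetch_to_device`, allows further transformations such as `prefetch` to run on the target device. Device strings assume a single-GPU host; on a GPU target an initializable iterator is required (see `make_one_shot_iterator` in the class below):

```python
import tensorflow as tf

with tf.device("/cpu:0"):
  dataset = tf.data.Dataset.range(10)

# Copy elements to the GPU, then keep a small prefetch buffer there.
dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
with tf.device("/gpu:0"):
  dataset = dataset.prefetch(1)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(next_element))  # 0
```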
+
+
+# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
+# all inputs to the Op are in host memory, thereby avoiding some unnecessary
+# Sends and Recvs.
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that copies elements to another device."""
+
+ def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
+ """Constructs a _CopyToDeviceDataset.
+
+ Args:
+ input_dataset: `Dataset` to be copied.
+ target_device: The name of the device to which elements will be copied.
+ source_device: The device on which `input_dataset` is placed.
+ """
+ super(_CopyToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._target_device = target_device
+ spec = framework_device.DeviceSpec().from_string(self._target_device)
+ self._is_gpu_target = (spec.device_type == "GPU")
+ self._source_device_string = source_device
+ self._source_device = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._input_dataset.output_shapes,
+ self._input_dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes))
+
+ @function.Defun()
+ def _init_func():
+ """Creates an iterator for the input dataset.
+
+ Returns:
+ A `string` tensor that encapsulates the iterator created.
+ """
+ # pylint: disable=protected-access
+ ds_variant = self._input_dataset._as_variant_tensor()
+ resource = gen_dataset_ops.anonymous_iterator(
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies(
+ [gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return gen_dataset_ops.iterator_to_string_handle(resource)
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ """Calls get_next for created iterator.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ The elements generated from `input_dataset`
+ """
+ with ops.device(self._source_device_string):
+ iterator = iterator_ops.Iterator.from_string_handle(
+ string_handle, self.output_types, self.output_shapes,
+ self.output_classes)
+ ret = iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(string_handle):
+ """Destroys the iterator resource created.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ Tensor constant 0
+ """
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies([
+ resource_variable_ops.destroy_resource_op(
+ iterator_resource, ignore_lookup_error=True)]):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ g = ops.get_default_graph()
+ _remote_init_func.add_to_graph(g)
+ _remote_next_func.add_to_graph(g)
+ _remote_finalize_func.add_to_graph(g)
+ # pylint: enable=protected-access
+
+ # The one_shot_iterator implementation needs a 0 arg _make_dataset function
+ # that thereby captures all the inputs required to create the dataset. Since
+ # there are strings that are inputs to the GeneratorDataset which can't be
+ # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
+ # GPU
+ def make_one_shot_iterator(self):
+ if self._is_gpu_target:
+ raise ValueError("Cannot create a one shot iterator when using "
+ "`tf.data.experimental.copy_to_device()` on GPU. Please "
+ "use `Dataset.make_initializable_iterator()` instead.")
+ else:
+ return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py
new file mode 100644
index 0000000000..e3a2aeab31
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/random_ops.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Datasets for random number generators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import random_seed
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.RandomDataset")
+class RandomDataset(dataset_ops.DatasetSource):
+ """A `Dataset` of pseudorandom values."""
+
+ def __init__(self, seed=None):
+ """A `Dataset` of pseudorandom values."""
+ super(RandomDataset, self).__init__()
+ self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.random_dataset(
+ seed=self._seed,
+ seed2=self._seed2,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.int64
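A small sketch of `RandomDataset`, which yields an unbounded stream of scalar `int64` values; the seed and the `take(3)` cap are illustrative, and the sketch assumes the `tf.data.experimental` endpoint introduced by this change:

```python
import tensorflow as tf

dataset = tf.data.experimental.RandomDataset(seed=42).take(3)
next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  for _ in range(3):
    print(sess.run(next_element))  # three pseudorandom int64 scalars
```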
diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py
new file mode 100644
index 0000000000..3b2d094514
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/readers.py
@@ -0,0 +1,904 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for reader Datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import csv
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import parsing_ops
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import convert
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.util.tf_export import tf_export
+
+_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.int64, dtypes.string)
+
+
+def _is_valid_int32(str_val):
+ try:
+ # Checks equality to prevent int32 overflow
+ return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
+ str_val)
+ except (ValueError, OverflowError):
+ return False
+
+
+def _is_valid_int64(str_val):
+ try:
+ dtypes.int64.as_numpy_dtype(str_val)
+ return True
+ except (ValueError, OverflowError):
+ return False
+
+
+def _is_valid_float(str_val, float_dtype):
+ try:
+ return float_dtype.as_numpy_dtype(str_val) < np.inf
+ except ValueError:
+ return False
+
+
+def _infer_type(str_val, na_value, prev_type):
+ """Given a string, infers its tensor type.
+
+ Infers the type of a value by picking the least 'permissive' type possible,
+ while still allowing the previous type inference for this column to be valid.
+
+ Args:
+ str_val: String value to infer the type of.
+ na_value: Additional string to recognize as a NA/NaN CSV value.
+ prev_type: Type previously inferred based on values of this column that
+ we've seen up till now.
+ Returns:
+ Inferred dtype.
+ """
+ if str_val in ("", na_value):
+ # If the field is null, it gives no extra information about its type
+ return prev_type
+
+ type_list = [
+ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
+ ] # list of types to try, ordered from least permissive to most
+
+ type_functions = [
+ _is_valid_int32,
+ _is_valid_int64,
+ lambda str_val: _is_valid_float(str_val, dtypes.float32),
+ lambda str_val: _is_valid_float(str_val, dtypes.float64),
+ lambda str_val: True,
+ ] # Corresponding list of validation functions
+
+ for i in range(len(type_list)):
+ validation_fn = type_functions[i]
+ if validation_fn(str_val) and (prev_type is None or
+ prev_type in type_list[:i + 1]):
+ return type_list[i]
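A few illustrative expectations for `_infer_type`, assuming the private helper is imported from this module (e.g. in a test): the result is the least permissive dtype that both parses `str_val` and is at least as permissive as `prev_type`.

```python
from tensorflow.python.data.experimental.ops.readers import _infer_type
from tensorflow.python.framework import dtypes

# First value seen in a column: prev_type is None.
assert _infer_type("12", na_value="", prev_type=None) == dtypes.int32
# A float literal widens an int column to float32.
assert _infer_type("3.5", na_value="", prev_type=dtypes.int64) == dtypes.float32
# An integer-looking value keeps an existing float32 column at float32.
assert _infer_type("3", na_value="", prev_type=dtypes.float32) == dtypes.float32
# The NA value carries no type information and preserves the previous result.
assert _infer_type("", na_value="", prev_type=dtypes.string) == dtypes.string
```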
+
+
+def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
+ """Generator that yields rows of CSV file(s) in order."""
+ for fn in filenames:
+ with file_io.FileIO(fn, "r") as f:
+ rdr = csv.reader(
+ f,
+ delimiter=field_delim,
+ quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
+ if header:
+ next(rdr) # Skip header lines
+
+ for csv_row in rdr:
+ if len(csv_row) != num_cols:
+ raise ValueError(
+ "Problem inferring types: CSV row has different number of fields "
+ "than expected.")
+ yield csv_row
+
+
+def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
+ na_value, header, num_rows_for_inference,
+ select_columns):
+ """Infers column types from the first N valid CSV records of files."""
+ if select_columns is None:
+ select_columns = range(num_cols)
+ inferred_types = [None] * len(select_columns)
+
+ for i, csv_row in enumerate(
+ _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
+ if num_rows_for_inference is not None and i >= num_rows_for_inference:
+ break
+
+ for j, col_index in enumerate(select_columns):
+ inferred_types[j] = _infer_type(csv_row[col_index], na_value,
+ inferred_types[j])
+
+ # Replace None's with a default type
+ inferred_types = [t or dtypes.string for t in inferred_types]
+ # Default to 0 or '' for null values
+ return [
+ constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
+ for t in inferred_types
+ ]
+
+
+def _infer_column_names(filenames, field_delim, use_quote_delim):
+ """Infers column names from first rows of files."""
+ csv_kwargs = {
+ "delimiter": field_delim,
+ "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
+ }
+ with file_io.FileIO(filenames[0], "r") as f:
+ try:
+ column_names = next(csv.reader(f, **csv_kwargs))
+ except StopIteration:
+ raise ValueError(("Received StopIteration when reading the header line "
+ "of %s. Empty file?") % filenames[0])
+
+ for name in filenames[1:]:
+ with file_io.FileIO(name, "r") as f:
+ try:
+ if next(csv.reader(f, **csv_kwargs)) != column_names:
+ raise ValueError(
+ "Files have different column names in the header row.")
+ except StopIteration:
+ raise ValueError(("Received StopIteration when reading the header line "
+ "of %s. Empty file?") % filenames[0])
+ return column_names
+
+
+def _get_sorted_col_indices(select_columns, column_names):
+ """Transforms select_columns argument into sorted column indices."""
+ names_to_indices = {n: i for i, n in enumerate(column_names)}
+ num_cols = len(column_names)
+ for i, v in enumerate(select_columns):
+ if isinstance(v, int):
+ if v < 0 or v >= num_cols:
+ raise ValueError(
+ "Column index %d specified in select_columns out of valid range." %
+ v)
+ continue
+ if v not in names_to_indices:
+ raise ValueError(
+ "Value '%s' specified in select_columns not a valid column index or "
+ "name." % v)
+ select_columns[i] = names_to_indices[v]
+
+ # Sort and ensure there are no duplicates
+ result = sorted(set(select_columns))
+ if len(result) != len(select_columns):
+ raise ValueError("select_columns contains duplicate columns")
+ return result
+
+
+def _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
+ """Optionally shuffle and repeat dataset, as requested."""
+ if num_epochs != 1 and shuffle:
+ # Use shuffle_and_repeat for perf
+ return dataset.apply(
+ shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
+ shuffle_seed))
+ elif shuffle:
+ return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
+ elif num_epochs != 1:
+ return dataset.repeat(num_epochs)
+ return dataset
+
+
+def make_tf_record_dataset(file_pattern,
+ batch_size,
+ parser_fn=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=None,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=None,
+ num_parallel_parser_calls=None,
+ drop_final_batch=False):
+ """Reads and optionally parses TFRecord files into a dataset.
+
+ Provides common functionality such as batching, optional parsing, shuffling,
+ and performant defaults.
+
+ Args:
+ file_pattern: List of files or patterns of TFRecord file paths.
+ See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ parser_fn: (Optional.) A function accepting string input to parse
+ and process the record contents. This function must map records
+ to components of a fixed shape, so they may be batched. By
+ default, uses the record contents unmodified.
+ num_epochs: (Optional.) An int specifying the number of times this
+ dataset is repeated. If None (the default), cycles through the
+ dataset forever.
+ shuffle: (Optional.) A bool that indicates whether the input
+ should be shuffled. Defaults to `True`.
+ shuffle_buffer_size: (Optional.) Buffer size to use for
+ shuffling. A large buffer size ensures better shuffling, but
+ increases memory usage and startup time.
+ shuffle_seed: (Optional.) Randomization seed to use for shuffling.
+ prefetch_buffer_size: (Optional.) An int specifying the number of
+ feature batches to prefetch for performance improvement.
+ Defaults to auto-tune. Set to 0 to disable prefetching.
+ num_parallel_reads: (Optional.) Number of threads used to read
+ records from files. By default, or if set to a value greater than 1,
+ the results will be interleaved.
+ num_parallel_parser_calls: (Optional.) Number of records to parse in
+ parallel. Defaults to an automatic selection.
+ drop_final_batch: (Optional.) Whether the last batch should be
+ dropped in case its size is smaller than `batch_size`; the
+ default behavior is not to drop the smaller batch.
+
+ Returns:
+ A dataset, where each element matches the output of `parser_fn`
+ except it will have an additional leading `batch_size` dimension,
+ or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
+ unspecified.
+ """
+ files = dataset_ops.Dataset.list_files(
+ file_pattern, shuffle=shuffle, seed=shuffle_seed)
+
+ if num_parallel_reads is None:
+ # Note: We considered auto-tuning this value, but there is a concern
+ # that this affects the mixing of records from different files, which
+ # could affect training convergence/accuracy, so we are defaulting to
+ # a constant for now.
+ num_parallel_reads = 24
+ dataset = core_readers.TFRecordDataset(
+ files, num_parallel_reads=num_parallel_reads)
+
+ if shuffle_buffer_size is None:
+ # TODO(josh11b): Auto-tune this value when not specified
+ shuffle_buffer_size = 10000
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ drop_final_batch = drop_final_batch or num_epochs is None
+
+ if parser_fn is None:
+ dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
+ else:
+ # TODO(josh11b): if num_parallel_parser_calls is None, use some function
+ # of num cores instead of map_and_batch's default behavior of one batch.
+ dataset = dataset.apply(batching.map_and_batch(
+ parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
+ drop_remainder=drop_final_batch))
+
+ if prefetch_buffer_size == 0:
+ return dataset
+ else:
+ return dataset.prefetch(buffer_size=prefetch_buffer_size)
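A hedged sketch of `make_tf_record_dataset` with a per-record parser. Since no public export is added for it in this file, the call goes through the module path used by this patch; the file pattern and feature spec are placeholders:

```python
import tensorflow as tf

from tensorflow.python.data.experimental.ops import readers


def parser_fn(serialized):
  # Map each serialized tf.Example to a fixed-shape component so that the
  # records can be batched by `map_and_batch`.
  parsed = tf.parse_single_example(
      serialized, {"label": tf.FixedLenFeature([], tf.int64)})
  return parsed["label"]


dataset = readers.make_tf_record_dataset(
    file_pattern="/tmp/train-*.tfrecord",  # placeholder pattern
    batch_size=32,
    parser_fn=parser_fn,
    num_epochs=1,
    shuffle=True,
    shuffle_buffer_size=1000)
# Each element is a length-32 vector of labels (the final batch may be
# smaller, since drop_final_batch defaults to False and num_epochs is set).
```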
+
+
+@tf_export("data.experimental.make_csv_dataset")
+def make_csv_dataset(
+ file_pattern,
+ batch_size,
+ column_names=None,
+ column_defaults=None,
+ label_name=None,
+ select_columns=None,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ header=True,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=1,
+ sloppy=False,
+ num_rows_for_inference=100,
+ compression_type=None,
+):
+ """Reads CSV files into a dataset.
+
+ Reads CSV files into a dataset, where each element is a (features, labels)
+ tuple that corresponds to a batch of CSV rows. The features dictionary
+ maps feature column names to `Tensor`s containing the corresponding
+ feature data, and labels is a `Tensor` containing the batch's label data.
+
+ Args:
+ file_pattern: List of files or patterns of file paths containing CSV
+ records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ column_names: An optional list of strings that corresponds to the CSV
+ columns, in order. One per column of the input record. If this is not
+ provided, infers the column names from the first row of the records.
+ These names will be the keys of the features dict of each dataset element.
+ column_defaults: An optional list of default values for the CSV fields. One
+ item per selected column of the input record. Each item in the list is
+ either a valid CSV dtype (float32, float64, int32, int64, or string), or a
+ `Tensor` with one of the aforementioned types. The tensor can either be
+ a scalar default value (if the column is optional), or an empty tensor (if
+ the column is required). If a dtype is provided instead of a tensor, the
+ column is also treated as required. If this list is not provided, tries
+ to infer types based on reading the first num_rows_for_inference rows of
+ files specified, and assumes all columns are optional, defaulting to `0`
+ for numeric values and `""` for string values. If both this and
+ `select_columns` are specified, these must have the same lengths, and
+ `column_defaults` is assumed to be sorted in order of increasing column
+ index.
+ label_name: An optional string corresponding to the label column. If
+ provided, the data for this column is returned as a separate `Tensor` from
+ the features dictionary, so that the dataset complies with the format
+ expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
+ function.
+ select_columns: An optional list of integer indices or string column
+ names, that specifies a subset of columns of CSV data to select. If
+ column names are provided, these must correspond to names provided in
+ `column_names` or inferred from the file header lines. When this argument
+ is specified, only a subset of CSV columns will be parsed and returned,
+ corresponding to the columns specified. Using this results in faster
+ parsing and lower memory usage. If both this and `column_defaults` are
+ specified, these must have the same lengths, and `column_defaults` is
+ assumed to be sorted in order of increasing column index.
+ field_delim: An optional `string`. Defaults to `","`. Char delimiter to
+ separate fields in a record.
+ use_quote_delim: An optional bool. Defaults to `True`. If false, treats
+ double quotation marks as regular characters inside of the string fields.
+ na_value: Additional string to recognize as NA/NaN.
+ header: A bool that indicates whether the first rows of provided CSV files
+ correspond to header lines with column names, and should not be included
+ in the data.
+ num_epochs: An int specifying the number of times this dataset is repeated.
+ If None, cycles through the dataset forever.
+ shuffle: A bool that indicates whether the input should be shuffled.
+ shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
+ ensures better shuffling, but increases memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: An int specifying the number of feature
+ batches to prefetch for performance improvement. Recommended value is the
+ number of batches consumed per training step. Defaults to auto-tune.
+
+ num_parallel_reads: Number of threads used to read CSV records from files.
+ If >1, the results will be interleaved.
+ sloppy: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`. Note that if the seed is set, then order
+ of elements after shuffling is deterministic). Defaults to `False`.
+ num_rows_for_inference: Number of rows of a file to use for type inference
+ if column_defaults is not provided. If None, reads all the rows of all
+ the files. Defaults to 100.
+ compression_type: (Optional.) A `tf.string` scalar evaluating to one of
+ `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
+
+ Returns:
+ A dataset, where each element is a (features, labels) tuple that corresponds
+ to a batch of `batch_size` CSV rows. The features dictionary maps feature
+ column names to `Tensor`s containing the corresponding column data, and
+ labels is a `Tensor` containing the column data for the label column
+ specified by `label_name`.
+
+ Raises:
+ ValueError: If any of the arguments is malformed.
+ """
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Clean arguments; figure out column names and defaults
+
+ if column_names is None:
+ if not header:
+ raise ValueError("Cannot infer column names without a header line.")
+ # If column names are not provided, infer from the header lines
+ column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
+ if len(column_names) != len(set(column_names)):
+ raise ValueError("Cannot have duplicate column names.")
+
+ if select_columns is not None:
+ select_columns = _get_sorted_col_indices(select_columns, column_names)
+
+ if column_defaults is not None:
+ column_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in column_defaults
+ ]
+ else:
+ # If column defaults are not provided, infer from records at graph
+ # construction time
+ column_defaults = _infer_column_defaults(
+ filenames, len(column_names), field_delim, use_quote_delim, na_value,
+ header, num_rows_for_inference, select_columns)
+
+ if select_columns is not None and len(column_defaults) != len(select_columns):
+ raise ValueError(
+ "If specified, column_defaults and select_columns must have same "
+ "length."
+ )
+ if select_columns is not None and len(column_names) > len(select_columns):
+ # Pick the relevant subset of column names
+ column_names = [column_names[i] for i in select_columns]
+
+ if label_name is not None and label_name not in column_names:
+ raise ValueError("`label_name` provided must be one of the columns.")
+
+ def filename_to_dataset(filename):
+ return CsvDataset(
+ filename,
+ record_defaults=column_defaults,
+ field_delim=field_delim,
+ use_quote_delim=use_quote_delim,
+ na_value=na_value,
+ select_cols=select_columns,
+ header=header,
+ compression_type=compression_type,
+ )
+
+ def map_fn(*columns):
+ """Organizes columns into a features dictionary.
+
+ Args:
+ *columns: list of `Tensor`s corresponding to one csv record.
+ Returns:
+ An OrderedDict of feature names to values for that particular record. If
+ label_name is provided, extracts the label feature to be returned as the
+ second element of the tuple.
+ """
+ features = collections.OrderedDict(zip(column_names, columns))
+ if label_name is not None:
+ label = features.pop(label_name)
+ return features, label
+ return features
+
+ # Read files sequentially (if num_parallel_reads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
+
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # Apply batch before map for perf, because map has high overhead relative
+ # to the size of the computation in each map.
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(batch_size=batch_size,
+ drop_remainder=num_epochs is None)
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=False)
+ dataset = dataset.prefetch(prefetch_buffer_size)
+
+ return dataset
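A minimal sketch of `make_csv_dataset` reading header-bearing CSV files and splitting off a label column. The file pattern and column name are placeholders; with no `column_defaults` provided, column types are inferred from the first `num_rows_for_inference` rows at graph-construction time:

```python
import tensorflow as tf

dataset = tf.data.experimental.make_csv_dataset(
    "/tmp/train-*.csv",    # placeholder pattern; files must have a header row
    batch_size=4,
    label_name="price",    # placeholder; must match one of the header columns
    num_epochs=1,
    shuffle=False)

features, label = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  batch_features, batch_label = sess.run((features, label))
  # `batch_features` maps each remaining column name to a length-4 tensor;
  # `batch_label` holds the corresponding "price" values.
```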
+
+
+_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
+
+
+@tf_export("data.experimental.CsvDataset")
+class CsvDataset(dataset_ops.DatasetSource):
+ """A Dataset comprising lines from one or more CSV files."""
+
+ def __init__(self,
+ filenames,
+ record_defaults,
+ compression_type=None,
+ buffer_size=None,
+ header=False,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ select_cols=None):
+ """Creates a `CsvDataset` by reading and decoding CSV files.
+
+ The elements of this dataset correspond to records from the file(s).
+ RFC 4180 format is expected for CSV files
+ (https://tools.ietf.org/html/rfc4180)
+ Note that we allow leading and trailing spaces in int or float fields.
+
+
+ For example, suppose we have a file 'my_file0.csv' with four CSV columns of
+ different data types:
+ ```
+ abcdefg,4.28E10,5.55E6,12
+ hijklmn,-5.3E14,,2
+ ```
+
+ We can construct a CsvDataset from it as follows:
+ ```python
+ dataset = tf.data.experimental.CsvDataset(
+ "my_file*.csv",
+ [tf.float32, # Required field, use dtype or empty tensor
+ tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
+ tf.int32, # Required field, use dtype or empty tensor
+ ],
+ select_cols=[1,2,3] # Only parse last three columns
+ )
+ ```
+
+ The expected output of its iterations is:
+ ```python
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with tf.Session() as sess:
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+
+ >> (4.28e10, 5.55e6, 12)
+ >> (-5.3e14, 0.0, 2)
+ ```
+
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ record_defaults: A list of default values for the CSV fields. Each item in
+ the list is either a valid CSV `DType` (float32, float64, int32, int64,
+ string), or a `Tensor` object with one of the above types. One per
+ column of CSV data, with either a scalar `Tensor` default value for the
+ column if it is optional, or `DType` or empty `Tensor` if required. If
+ both this and `select_cols` are specified, these must have the same
+ lengths, and `record_defaults` is assumed to be sorted in order of
+ increasing column index.
+ compression_type: (Optional.) A `tf.string` scalar evaluating to one of
+ `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
+ compression.
+ buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
+ to buffer while reading files. Defaults to 4MB.
+ header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
+ have header line(s) that should be skipped when parsing. Defaults to
+ `False`.
+ field_delim: (Optional.) A `tf.string` scalar containing the delimiter
+ character that separates fields in a record. Defaults to `","`.
+ use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
+ double quotation marks as regular characters inside of string fields
+ (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
+ na_value: (Optional.) A `tf.string` scalar indicating a value that will
+ be treated as NA/NaN.
+ select_cols: (Optional.) A sorted list of column indices to select from
+ the input data. If specified, only this subset of columns will be
+ parsed. Defaults to parsing all columns.
+ """
+ super(CsvDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+ record_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in record_defaults
+ ]
+ self._record_defaults = ops.convert_n_to_tensor(
+ record_defaults, name="record_defaults")
+ self._buffer_size = convert.optional_param_to_tensor(
+ "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
+ self._header = ops.convert_to_tensor(
+ header, dtype=dtypes.bool, name="header")
+ self._field_delim = ops.convert_to_tensor(
+ field_delim, dtype=dtypes.string, name="field_delim")
+ self._use_quote_delim = ops.convert_to_tensor(
+ use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
+ self._na_value = ops.convert_to_tensor(
+ na_value, dtype=dtypes.string, name="na_value")
+ self._select_cols = convert.optional_param_to_tensor(
+ "select_cols",
+ select_cols,
+ argument_default=[],
+ argument_dtype=dtypes.int64,
+ )
+ self._output_shapes = tuple(
+ tensor_shape.scalar() for _ in range(len(record_defaults)))
+ self._output_types = tuple(d.dtype for d in self._record_defaults)
+ self._output_classes = tuple(
+ ops.Tensor for _ in range(len(record_defaults)))
+
+ def _as_variant_tensor(self):
+ # Constructs graph node for the dataset op.
+ return gen_experimental_dataset_ops.experimental_csv_dataset(
+ filenames=self._filenames,
+ record_defaults=self._record_defaults,
+ buffer_size=self._buffer_size,
+ header=self._header,
+ output_shapes=self._output_shapes,
+ field_delim=self._field_delim,
+ use_quote_delim=self._use_quote_delim,
+ na_value=self._na_value,
+ select_cols=self._select_cols,
+ compression_type=self._compression_type,
+ )
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+@tf_export("data.experimental.make_batched_features_dataset")
+def make_batched_features_dataset(file_pattern,
+ batch_size,
+ features,
+ reader=core_readers.TFRecordDataset,
+ label_key=None,
+ reader_args=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ reader_num_threads=1,
+ parser_num_threads=2,
+ sloppy_ordering=False,
+ drop_final_batch=False):
+ """Returns a `Dataset` of feature dictionaries from `Example` protos.
+
+ If the `label_key` argument is provided, returns a `Dataset` of
+ `(feature dictionary, label)` tuples.
+
+ Example:
+
+ ```
+ serialized_examples = [
+ features {
+ feature { key: "age" value { int64_list { value: [ 0 ] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
+ },
+ features {
+ feature { key: "age" value { int64_list { value: [] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
+ }
+ ]
+ ```
+
+ We can use arguments:
+
+ ```
+ features: {
+ "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
+ "gender": FixedLenFeature([], dtype=tf.string),
+ "kws": VarLenFeature(dtype=tf.string),
+ }
+ ```
+
+ And the expected output is:
+
+ ```python
+ {
+ "age": [[0], [-1]],
+ "gender": [["f"], ["f"]],
+ "kws": SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0]],
+      values=["code", "art", "sports"],
+ dense_shape=[2, 2]),
+ }
+ ```
+
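+  Given a `features` dictionary like the one above, a minimal invocation might
+  look like the following sketch (the file pattern is hypothetical):
+
+  ```python
+  dataset = tf.data.experimental.make_batched_features_dataset(
+      file_pattern="/path/to/examples-*.tfrecord",  # hypothetical path
+      batch_size=2,
+      features=features,
+      num_epochs=1)
+  ```
+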
+ Args:
+ file_pattern: List of files or patterns of file paths containing
+ `Example` records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ features: A `dict` mapping feature keys to `FixedLenFeature` or
+ `VarLenFeature` values. See `tf.parse_example`.
+ reader: A function or class that can be
+ called with a `filenames` tensor and (optional) `reader_args` and returns
+ a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+    label_key: (Optional.) A string corresponding to the key under which labels
+      are stored in the `tf.Example` protos. If provided, it must be one of the
+      `features` keys; otherwise a `ValueError` is raised.
+ reader_args: Additional arguments to pass to the reader class.
+ num_epochs: Integer specifying the number of times to read through the
+ dataset. If None, cycles through the dataset forever. Defaults to `None`.
+    shuffle: A boolean that indicates whether the input should be shuffled.
+      Defaults to `True`.
+    shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
+      ensures better shuffling but increases memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: Number of feature batches to prefetch in order to
+ improve performance. Recommended value is the number of batches consumed
+ per training step. Defaults to auto-tune.
+ reader_num_threads: Number of threads used to read `Example` records. If >1,
+ the results will be interleaved.
+ parser_num_threads: Number of threads to use for parsing `Example` tensors
+ into a dictionary of `Feature` tensors.
+    sloppy_ordering: If `True`, reading performance will be improved at
+      the cost of non-deterministic ordering. If `False`, the order of elements
+      produced is deterministic prior to shuffling (elements are still
+      randomized if `shuffle=True`; note that if a seed is set, the order of
+      elements after shuffling is deterministic). Defaults to `False`.
+ drop_final_batch: If `True`, and the batch size does not evenly divide the
+ input dataset size, the final smaller batch will be dropped. Defaults to
+ `False`.
+
+ Returns:
+    A dataset of `dict` elements (or a tuple of `dict` elements and a label).
+ Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
+
+ Raises:
+ ValueError: If `label_key` is not one of the `features` keys.
+ """
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Read `Example` records from files as tensor objects.
+ if reader_args is None:
+ reader_args = []
+
+ # Read files sequentially (if reader_num_threads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ lambda filename: reader(filename, *reader_args),
+ cycle_length=reader_num_threads,
+ sloppy=sloppy_ordering))
+
+ # Extract values if the `Example` tensors are stored as key-value tuples.
+ if dataset.output_types == (dtypes.string, dtypes.string):
+ dataset = dataset_ops.MapDataset(
+ dataset, lambda _, v: v, use_inter_op_parallelism=False)
+
+ # Apply dataset repeat and shuffle transformations.
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(
+ batch_size, drop_remainder=drop_final_batch or num_epochs is None)
+
+ # Parse `Example` tensors to a dictionary of `Feature` tensors.
+ dataset = dataset.apply(
+ parsing_ops.parse_example_dataset(
+ features, num_parallel_calls=parser_num_threads))
+
+ if label_key:
+ if label_key not in features:
+ raise ValueError(
+ "The `label_key` provided (%r) must be one of the `features` keys." %
+ label_key)
+ dataset = dataset.map(lambda x: (x, x.pop(label_key)))
+
+ dataset = dataset.prefetch(prefetch_buffer_size)
+ return dataset
+
+
+def _get_file_names(file_pattern, shuffle):
+ """Parse list of file names from pattern, optionally shuffled.
+
+ Args:
+ file_pattern: File glob pattern, or list of glob patterns.
+ shuffle: Whether to shuffle the order of file names.
+
+ Returns:
+ List of file names matching `file_pattern`.
+
+ Raises:
+ ValueError: If `file_pattern` is empty, or pattern matches no files.
+ """
+ if isinstance(file_pattern, list):
+ if not file_pattern:
+ raise ValueError("File pattern is empty.")
+ file_names = []
+ for entry in file_pattern:
+ file_names.extend(gfile.Glob(entry))
+ else:
+ file_names = list(gfile.Glob(file_pattern))
+
+ if not file_names:
+ raise ValueError("No files match %s." % file_pattern)
+
+  # Sort files so the order will be deterministic for unit tests.
+ if not shuffle:
+ file_names = sorted(file_names)
+ return file_names
+
+
+@tf_export("data.experimental.SqlDataset")
+class SqlDataset(dataset_ops.DatasetSource):
+ """A `Dataset` consisting of the results from a SQL query."""
+
+ def __init__(self, driver_name, data_source_name, query, output_types):
+ """Creates a `SqlDataset`.
+
+ `SqlDataset` allows a user to read data from the result set of a SQL query.
+ For example:
+
+ ```python
+ dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
+ "SELECT name, age FROM people",
+ (tf.string, tf.int32))
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+    # Prints the rows of the result set of the above query.
+    with tf.Session() as sess:
+      while True:
+        try:
+          print(sess.run(next_element))
+        except tf.errors.OutOfRangeError:
+          break
+ ```
+
+ Args:
+ driver_name: A 0-D `tf.string` tensor containing the database type.
+ Currently, the only supported value is 'sqlite'.
+ data_source_name: A 0-D `tf.string` tensor containing a connection string
+ to connect to the database.
+ query: A 0-D `tf.string` tensor containing the SQL query to execute.
+ output_types: A tuple of `tf.DType` objects representing the types of the
+ columns returned by `query`.
+ """
+ super(SqlDataset, self).__init__()
+ self._driver_name = ops.convert_to_tensor(
+ driver_name, dtype=dtypes.string, name="driver_name")
+ self._data_source_name = ops.convert_to_tensor(
+ data_source_name, dtype=dtypes.string, name="data_source_name")
+ self._query = ops.convert_to_tensor(
+ query, dtype=dtypes.string, name="query")
+ self._output_types = output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.sql_dataset(self._driver_name,
+ self._data_source_name, self._query,
+ nest.flatten(self.output_types),
+ nest.flatten(self.output_shapes))
+
+ @property
+ def output_classes(self):
+ return nest.map_structure(lambda _: ops.Tensor, self._output_types)
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
+ self._output_types)
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/python/data/experimental/ops/resampling.py b/tensorflow/python/data/experimental/ops/resampling.py
new file mode 100644
index 0000000000..3a3040ae9a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/resampling.py
@@ -0,0 +1,296 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Resampling dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.rejection_resample")
+def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
+ """A transformation that resamples a dataset to achieve a target distribution.
+
+ **NOTE** Resampling is performed via rejection sampling; some fraction
+ of the input values will be dropped.
+
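+  For example, the following sketch (assuming a dataset of `(label, features)`
+  pairs with scalar `tf.int32` labels in `{0, 1}`) rebalances the two classes:
+
+  ```python
+  resampled = dataset.apply(
+      tf.data.experimental.rejection_resample(
+          class_func=lambda label, features: label,
+          target_dist=[0.5, 0.5]))
+  # The transformation emits `(class value, original element)` pairs; drop the
+  # class value to recover the original `(label, features)` structure.
+  resampled = resampled.map(lambda class_value, element: element)
+  ```
+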
+ Args:
+ class_func: A function mapping an element of the input dataset to a scalar
+ `tf.int32` tensor. Values should be in `[0, num_classes)`.
+ target_dist: A floating point type tensor, shaped `[num_classes]`.
+ initial_dist: (Optional.) A floating point type tensor, shaped
+ `[num_classes]`. If not provided, the true class distribution is
+ estimated live in a streaming fashion.
+ seed: (Optional.) Python integer seed for the resampler.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
+ class_values_ds = dataset.map(class_func)
+
+ # Get initial distribution.
+ if initial_dist is not None:
+ initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
+ acceptance_dist, prob_of_original = (
+ _calculate_acceptance_probs_with_mixing(initial_dist_t,
+ target_dist_t))
+ initial_dist_ds = dataset_ops.Dataset.from_tensors(
+ initial_dist_t).repeat()
+ acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
+ acceptance_dist).repeat()
+ prob_of_original_ds = dataset_ops.Dataset.from_tensors(
+ prob_of_original).repeat()
+ else:
+ initial_dist_ds = _estimate_initial_dist_ds(
+ target_dist_t, class_values_ds)
+ acceptance_and_original_prob_ds = initial_dist_ds.map(
+ lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
+ initial, target_dist_t))
+ acceptance_dist_ds = acceptance_and_original_prob_ds.map(
+ lambda accept_prob, _: accept_prob)
+ prob_of_original_ds = acceptance_and_original_prob_ds.map(
+ lambda _, prob_original: prob_original)
+ filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
+ class_values_ds, seed)
+ # Prefetch filtered dataset for speed.
+ filtered_ds = filtered_ds.prefetch(3)
+
+ prob_original_static = _get_prob_original_static(
+ initial_dist_t, target_dist_t) if initial_dist is not None else None
+ if prob_original_static == 1:
+ return dataset_ops.Dataset.zip((class_values_ds, dataset))
+ elif prob_original_static == 0:
+ return filtered_ds
+ else:
+ return interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
+ weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
+ seed=seed)
+
+ return _apply_fn
+
+
+def _get_prob_original_static(initial_dist_t, target_dist_t):
+ """Returns the static probability of sampling from the original.
+
+ `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
+ an Op that it isn't defined for. We have some custom logic to avoid this.
+
+ Args:
+ initial_dist_t: A tensor of the initial distribution.
+ target_dist_t: A tensor of the target distribution.
+
+ Returns:
+ The probability of sampling from the original distribution as a constant,
+ if it is a constant, or `None`.
+ """
+ init_static = tensor_util.constant_value(initial_dist_t)
+ target_static = tensor_util.constant_value(target_dist_t)
+
+ if init_static is None or target_static is None:
+ return None
+ else:
+ return np.min(target_static / init_static)
+
+
+def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
+ seed):
+ """Filters a dataset based on per-class acceptance probabilities.
+
+ Args:
+ dataset: The dataset to be filtered.
+ acceptance_dist_ds: A dataset of acceptance probabilities.
+ initial_dist_ds: A dataset of the initial probability distribution, given or
+ estimated.
+ class_values_ds: A dataset of the corresponding classes.
+ seed: (Optional.) Python integer seed for the resampler.
+
+ Returns:
+ A dataset of (class value, data) after filtering.
+ """
+ def maybe_warn_on_large_rejection(accept_dist, initial_dist):
+ proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
+ return control_flow_ops.cond(
+ math_ops.less(proportion_rejected, .5),
+ lambda: accept_dist,
+ lambda: logging_ops.Print( # pylint: disable=g-long-lambda
+ accept_dist, [proportion_rejected, initial_dist, accept_dist],
+ message="Proportion of examples rejected by sampler is high: ",
+ summarize=100,
+ first_n=10))
+
+ acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
+ initial_dist_ds))
+ .map(maybe_warn_on_large_rejection))
+
+ def _gather_and_copy(class_val, acceptance_prob, data):
+ return class_val, array_ops.gather(acceptance_prob, class_val), data
+
+ current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
+ (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
+ filtered_ds = (
+ current_probabilities_and_class_and_data_ds
+ .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
+ return filtered_ds.map(lambda class_value, _, data: (class_value, data))
+
+
+def _estimate_initial_dist_ds(
+ target_dist_t, class_values_ds, dist_estimation_batch_size=32,
+ smoothing_constant=10):
+ num_classes = (target_dist_t.shape[0].value or
+ array_ops.shape(target_dist_t)[0])
+ initial_examples_per_class_seen = array_ops.fill(
+ [num_classes], np.int64(smoothing_constant))
+
+ def update_estimate_and_tile(num_examples_per_class_seen, c):
+ updated_examples_per_class_seen, dist = _estimate_data_distribution(
+ c, num_examples_per_class_seen)
+ tiled_dist = array_ops.tile(
+ array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
+ return updated_examples_per_class_seen, tiled_dist
+
+ initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
+ .apply(scan_ops.scan(initial_examples_per_class_seen,
+ update_estimate_and_tile))
+ .apply(batching.unbatch()))
+
+ return initial_dist_ds
+
+
+def _get_target_to_initial_ratio(initial_probs, target_probs):
+ # Add tiny to initial_probs to avoid divide by zero.
+ denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
+ return target_probs / denom
+
+
+def _estimate_data_distribution(c, num_examples_per_class_seen):
+ """Estimate data distribution as labels are seen.
+
+ Args:
+ c: The class labels. Type `int32`, shape `[batch_size]`.
+ num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
+ containing counts.
+
+ Returns:
+    num_examples_per_class_seen: Updated counts. Type `int64`, shape
+ `[num_classes]`.
+ dist: The updated distribution. Type `float32`, shape `[num_classes]`.
+ """
+ num_classes = num_examples_per_class_seen.get_shape()[0].value
+ # Update the class-count based on what labels are seen in batch.
+ num_examples_per_class_seen = math_ops.add(
+ num_examples_per_class_seen, math_ops.reduce_sum(
+ array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
+ init_prob_estimate = math_ops.truediv(
+ num_examples_per_class_seen,
+ math_ops.reduce_sum(num_examples_per_class_seen))
+ dist = math_ops.cast(init_prob_estimate, dtypes.float32)
+ return num_examples_per_class_seen, dist
+
+
+def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
+ """Calculates the acceptance probabilities and mixing ratio.
+
+ In this case, we assume that we can *either* sample from the original data
+ distribution with probability `m`, or sample from a reshaped distribution
+ that comes from rejection sampling on the original distribution. This
+ rejection sampling is done on a per-class basis, with `a_i` representing the
+ probability of accepting data from class `i`.
+
+ This method is based on solving the following analysis for the reshaped
+ distribution:
+
+  Let F be the probability of a rejection (on any example).
+  Let p_i be the proportion of examples in the data in class i (init_probs).
+  Let a_i be the rate at which the rejection sampler should *accept* class i.
+  Let t_i be the target proportion in the minibatches for class i (target_probs).
+
+ ```
+ F = sum_i(p_i * (1-a_i))
+ = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
+ ```
+
+  An example of class `i` is emitted when some number `k` of rejections occur,
+  followed by an example of class `i` being seen by the rejector and accepted.
+  Summing over all `k`, this can be written as follows:
+
+ ```
+ t_i = sum_k=0^inf(F^k * p_i * a_i)
+      = p_i * a_i / (1 - F) using geometric series identity, since 0 <= F < 1
+ = p_i * a_i / sum_j(p_j * a_j) using F from above
+ ```
+
+ Note that the following constraints hold:
+ ```
+ 0 <= p_i <= 1, sum_i(p_i) = 1
+ 0 <= a_i <= 1
+ 0 <= t_i <= 1, sum_i(t_i) = 1
+ ```
+
+  Without mixing, a solution for a_i in terms of the other variables is:
+  ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
+
+  When mixing with the original distribution is allowed and we try to minimize
+  the amount of data rejected, define:
+
+  M_max = max_i [ t_i / p_i ]
+  M_min = min_i [ t_i / p_i ]
+
+  The desired probability of accepting data if it comes from class `i` is then:
+
+  a_i = (t_i / p_i - m) / (M_max - m)
+
+  where `m` is the desired probability of pulling a data element from the
+  original dataset rather than from the filtered one:
+
+  m = M_min
+
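+  As a concrete check (a worked example, not part of the implementation): with
+  p = [0.9, 0.1] and t = [0.5, 0.5], the ratios t_i / p_i are roughly
+  [0.56, 5.0], so m = M_min ~= 0.56 and a = [0.0, 1.0]. Drawing from the
+  original distribution with probability m and from the filtered distribution
+  (which accepts only class 1) with probability 1 - m gives
+  0.56 * [0.9, 0.1] + 0.44 * [0.0, 1.0] ~= [0.5, 0.5], the target.
+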
+ Args:
+ initial_probs: A Tensor of the initial probability distribution, given or
+ estimated.
+    target_probs: A Tensor of the target probability distribution, shaped
+      `[num_classes]`.
+
+ Returns:
+    A tuple of (a 1-D Tensor with the per-class acceptance probabilities, the
+    desired probability of pulling from the original distribution).
+ """
+ ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
+ max_ratio = math_ops.reduce_max(ratio_l)
+ min_ratio = math_ops.reduce_min(ratio_l)
+
+ # Target prob to sample from original distribution.
+ m = min_ratio
+
+ # TODO(joelshor): Simplify fraction, if possible.
+ a_i = (ratio_l - m) / (max_ratio - m)
+ return a_i, m
diff --git a/tensorflow/python/data/experimental/ops/scan_ops.py b/tensorflow/python/data/experimental/ops/scan_ops.py
new file mode 100644
index 0000000000..e05e7c5a18
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/scan_ops.py
@@ -0,0 +1,177 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Scan dataset transformation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ScanDataset(dataset_ops.UnaryDataset):
+ """A dataset that scans a function across its input."""
+
+ def __init__(self, input_dataset, initial_state, scan_func):
+ """See `scan()` for details."""
+ super(_ScanDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ self._initial_state = nest.pack_sequence_as(initial_state, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(initial_state))
+ ])
+
+ # Compute initial values for the state classes, shapes and types based on
+    # the initial state. The shapes may be refined by running the wrapped
+    # `scan_func` one or more times below.
+ self._state_classes = sparse.get_classes(self._initial_state)
+ self._state_shapes = nest.pack_sequence_as(
+ self._initial_state,
+ [t.get_shape() for t in nest.flatten(self._initial_state)])
+ self._state_types = nest.pack_sequence_as(
+ self._initial_state,
+ [t.dtype for t in nest.flatten(self._initial_state)])
+
+    # Will be populated by running the wrapped `scan_func` below.
+ self._output_classes = None
+ self._output_shapes = None
+ self._output_types = None
+
+ # Iteratively rerun the scan function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ scan_func,
+ "tf.data.experimental.scan()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+ if not (
+ isinstance(wrapped_func.output_types, collections.Sequence) and
+ len(wrapped_func.output_types) == 2):
+ raise TypeError("The scan function must return a pair comprising the "
+ "new state and the output value.")
+
+ new_state_classes, self._output_classes = wrapped_func.output_classes
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(new_state_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, new_state_classes))
+
+ # Extract and validate type information from the returned values.
+ new_state_types, self._output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(new_state_types), nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, new_state_types))
+
+ # Extract shape information from the returned values.
+ new_state_shapes, self._output_shapes = wrapped_func.output_shapes
+
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(new_state_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._scan_func = wrapped_func.function
+ self._scan_func.add_to_graph(ops.get_default_graph())
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.scan_dataset(
+ input_t,
+ nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
+ self._scan_func.captured_inputs,
+ f=self._scan_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+@tf_export("data.experimental.scan")
+def scan(initial_state, scan_func):
+ """A transformation that scans a function across an input dataset.
+
+ This transformation is a stateful relative of `tf.data.Dataset.map`.
+ In addition to mapping `scan_func` across the elements of the input dataset,
+ `scan()` accumulates one or more state tensors, whose initial values are
+ `initial_state`.
+
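+  For example, the following sketch computes a running sum (assuming `numpy` is
+  imported as `np`; the initial state must match the element dtype):
+
+  ```python
+  # Emits 0, 1, 3, 6, 10, ...
+  dataset = tf.data.Dataset.range(10).apply(
+      tf.data.experimental.scan(
+          np.int64(0), lambda state, x: (state + x, state + x)))
+  ```
+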
+ Args:
+ initial_state: A nested structure of tensors, representing the initial state
+ of the accumulator.
+ scan_func: A function that maps `(old_state, input_element)` to
+      `(new_state, output_element)`. It must take two arguments and return a
+ pair of nested structures of tensors. The `new_state` must match the
+ structure of `initial_state`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return _ScanDataset(dataset, initial_state, scan_func)
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/shuffle_ops.py b/tensorflow/python/data/experimental/ops/shuffle_ops.py
new file mode 100644
index 0000000000..a4307212da
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/shuffle_ops.py
@@ -0,0 +1,102 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental shuffle ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import random_seed
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that fuses `shuffle` and `repeat`."""
+
+ def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+ super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._buffer_size = ops.convert_to_tensor(
+ buffer_size, dtype=dtypes.int64, name="buffer_size")
+ if count is None:
+ self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
+ else:
+ self._count = ops.convert_to_tensor(
+ count, dtype=dtypes.int64, name="count")
+ self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ input_resource = self._input_dataset._as_variant_tensor()
+ return gen_dataset_ops.shuffle_and_repeat_dataset(
+ input_resource,
+ buffer_size=self._buffer_size,
+ count=self._count,
+ seed=self._seed,
+ seed2=self._seed2,
+ **dataset_ops.flat_structure(self))
+ # pylint: enable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+@tf_export("data.experimental.shuffle_and_repeat")
+def shuffle_and_repeat(buffer_size, count=None, seed=None):
+ """Shuffles and repeats a Dataset returning a new permutation for each epoch.
+
+ `dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size, count))`
+
+ is equivalent to
+
+ `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`
+
+ The difference is that the latter dataset is not serializable. So,
+  if you need to checkpoint an input pipeline with reshuffling, you must use
+ this implementation.
+
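+  For example (a minimal sketch):
+
+  ```python
+  # Shuffle within a 10,000-element buffer and repeat for 5 epochs.
+  dataset = dataset.apply(
+      tf.data.experimental.shuffle_and_repeat(buffer_size=10000, count=5))
+  ```
+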
+ Args:
+    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
+      maximum number of elements that will be buffered when prefetching.
+    count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+      number of times the dataset should be repeated. The default behavior
+      (if `count` is `None` or `-1`) is for the dataset to be repeated
+      indefinitely.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ `tf.set_random_seed` for behavior.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset): # pylint: disable=missing-docstring
+ return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py
new file mode 100644
index 0000000000..c918d223e8
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/stats_ops.py
@@ -0,0 +1,205 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for gathering statistics from `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.StatsAggregator")
+class StatsAggregator(object):
+ """A stateful resource that aggregates statistics from one or more iterators.
+
+ To record statistics, use one of the custom transformation functions defined
+ in this module when defining your `tf.data.Dataset`. All statistics will be
+ aggregated by the `StatsAggregator` that is associated with a particular
+ iterator (see below). For example, to record the latency of producing each
+ element by iterating over a dataset:
+
+ ```python
+ dataset = ...
+  dataset = dataset.apply(tf.data.experimental.latency_stats("record_latency"))
+ ```
+
+ To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
+ the following pattern:
+
+ ```python
+  stats_aggregator = tf.data.experimental.StatsAggregator()
+ dataset = ...
+
+ # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
+ dataset = dataset.apply(
+ tf.data.experimental.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_one_shot_iterator()
+ ```
+
+ To get a protocol buffer summary of the currently aggregated statistics,
+ use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
+ is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
+ so that the summaries will be included with any existing summaries.
+
+ ```python
+  stats_aggregator = tf.data.experimental.StatsAggregator()
+ # ...
+ stats_summary = stats_aggregator.get_summary()
+ tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
+ ```
+
+ Note: This interface is experimental and expected to change. In particular,
+ we expect to add other implementations of `StatsAggregator` that provide
+ different ways of exporting statistics, and add more types of statistics.
+ """
+
+ def __init__(self):
+ """Creates a `StatsAggregator`."""
+ self._resource = gen_dataset_ops.stats_aggregator_handle()
+
+ # TODO(b/116314787): Update this/add support for V2 summary API.
+ def get_summary(self):
+ """Returns a string `tf.Tensor` that summarizes the aggregated statistics.
+
+ The returned tensor will contain a serialized `tf.summary.Summary` protocol
+ buffer, which can be used with the standard TensorBoard logging facilities.
+
+ Returns:
+ A scalar string `tf.Tensor` that summarizes the aggregated statistics.
+ """
+ return gen_dataset_ops.stats_aggregator_summary(self._resource)
+
+
+class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and sets given stats_aggregator."""
+
+ def __init__(self, input_dataset, stats_aggregator):
+ super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._stats_aggregator = stats_aggregator
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.set_stats_aggregator_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._stats_aggregator._resource, # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+@tf_export("data.experimental.set_stats_aggregator")
+def set_stats_aggregator(stats_aggregator):
+ """Set the given `stats_aggregator` for aggregating the input dataset stats.
+
+ Args:
+ stats_aggregator: A `tf.data.experimental.StatsAggregator` object.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _SetStatsAggregatorDataset(dataset, stats_aggregator)
+
+ return _apply_fn
+
+
+# TODO(b/38416882): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+def bytes_produced_stats(tag):
+ """Records the number of bytes produced by each element of the input dataset.
+
+ To consume the statistics, associate a `StatsAggregator` with the output
+ dataset.
+
+ Args:
+ tag: String. All statistics recorded by the returned transformation will
+ be associated with the given `tag`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset,
+ tag)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.latency_stats")
+def latency_stats(tag):
+ """Records the latency of producing each element of the input dataset.
+
+ To consume the statistics, associate a `StatsAggregator` with the output
+ dataset.
+
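+  For example (a minimal sketch; the tag name is arbitrary):
+
+  ```python
+  dataset = dataset.apply(
+      tf.data.experimental.latency_stats("record_latency"))
+  ```
+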
+ Args:
+ tag: String. All statistics recorded by the returned transformation will
+ be associated with the given `tag`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag)
+
+ return _apply_fn
+
+
+class _StatsDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and also records statistics."""
+
+ def __init__(self, input_dataset, op_function, tag):
+ super(_StatsDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._op_function = op_function
+ self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
+
+ def _as_variant_tensor(self):
+ return self._op_function(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._tag,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
diff --git a/tensorflow/python/data/experimental/ops/threadpool.py b/tensorflow/python/data/experimental/ops/threadpool.py
new file mode 100644
index 0000000000..3ea017c6e8
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/threadpool.py
@@ -0,0 +1,104 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for controlling threading in `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import resource_variable_ops
+
+_uid_counter = 0
+_uid_lock = threading.Lock()
+
+
+def _generate_shared_name(prefix):
+ with _uid_lock:
+ global _uid_counter
+ uid = _uid_counter
+ _uid_counter += 1
+ return "{}{}".format(prefix, uid)
+
+
+# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+class PrivateThreadPool(object):
+ """A stateful resource that represents a private thread pool."""
+
+ def __init__(self, num_threads, display_name=None,
+ max_intra_op_parallelism=1):
+ """Creates a `PrivateThreadPool` with the given number of threads."""
+ if context.executing_eagerly():
+ shared_name = _generate_shared_name("privatethreadpool")
+ self._resource = ged_ops.experimental_thread_pool_handle(
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name,
+ shared_name=shared_name)
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._resource, handle_device=context.context().device_name)
+ else:
+ self._resource = ged_ops.experimental_thread_pool_handle(
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name)
+
+
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and sets a custom threadpool."""
+
+ def __init__(self, input_dataset, thread_pool):
+ super(_ThreadPoolDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._thread_pool = thread_pool
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_thread_pool_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._thread_pool._resource, # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+def override_threadpool(dataset, thread_pool):
+ """Returns a new dataset that uses the given thread pool for its operations.
+
+ Args:
+ dataset: A `tf.data.Dataset` object.
+ thread_pool: A `PrivateThreadPool` object.
+
+ Returns:
+ A dataset containing the same values as `dataset`, but which uses
+ `thread_pool` to compute any of its parallel operations (such as
+ `tf.data.Dataset.map`).
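+
+  For example (a minimal sketch; note that neither name is currently exported
+  in the public `tf.data.experimental` API):
+
+  ```python
+  dataset = override_threadpool(
+      dataset, PrivateThreadPool(2, display_name="private_pool"))
+  ```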
+ """
+ return _ThreadPoolDataset(dataset, thread_pool)
diff --git a/tensorflow/python/data/experimental/ops/unique.py b/tensorflow/python/data/experimental/ops/unique.py
new file mode 100644
index 0000000000..2a7775c456
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/unique.py
@@ -0,0 +1,79 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unique element dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.unique")
+def unique():
+ """Creates a `Dataset` from another `Dataset`, discarding duplicates.
+
+ Use this transformation to produce a dataset that contains one instance of
+ each unique element in the input. For example:
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
+
+ # Using `unique()` will drop the duplicate elements.
+ dataset = dataset.apply(tf.data.experimental.unique()) # ==> { 1, 37, 2 }
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _UniqueDataset(dataset)
+
+ return _apply_fn
+
+
+class _UniqueDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` contains the unique elements from its input."""
+
+ def __init__(self, input_dataset):
+ """See `unique()` for details."""
+ super(_UniqueDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
+ dtypes.string):
+ raise TypeError(
+ "`tf.data.experimental.unique()` only supports inputs with a single "
+ "`tf.int32`, `tf.int64`, or `tf.string` component.")
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_unique_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py
new file mode 100644
index 0000000000..994447cb4d
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/writers.py
@@ -0,0 +1,60 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for tf.data writers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.TFRecordWriter")
+class TFRecordWriter(object):
+ """Writes data to a TFRecord file."""
+
+ def __init__(self, filename, compression_type=None):
+ self._filename = ops.convert_to_tensor(
+ filename, dtypes.string, name="filename")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+
+ def write(self, dataset):
+ """Returns a `tf.Operation` to write a dataset to a file.
+
+ Args:
+      dataset: A `tf.data.Dataset` whose elements are to be written to a file.
+
+ Returns:
+ A `tf.Operation` that, when run, writes contents of `dataset` to a file.
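+
+    For example (a minimal sketch; the output path is hypothetical):
+
+    ```python
+    dataset = tf.data.Dataset.from_tensor_slices([b"record one", b"record two"])
+    writer = tf.data.experimental.TFRecordWriter("/tmp/out.tfrecord")
+    write_op = writer.write(dataset)
+    with tf.Session() as sess:
+      sess.run(write_op)
+    ```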
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+ if (dataset.output_types != dtypes.string or
+ dataset.output_shapes != tensor_shape.scalar()):
+ raise TypeError(
+ "`dataset` must produce scalar `DT_STRING` tensors whereas it "
+ "produces shape {0} and types {1}".format(dataset.output_shapes,
+ dataset.output_types))
+ return gen_dataset_ops.dataset_to_tf_record(
+ dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6bba72a8e9..3b9d3a639d 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -889,8 +889,8 @@ class Dataset(object):
will be padded out to the maximum length of all elements in that
dimension.
- See also `tf.contrib.data.dense_to_sparse_batch`, which combines elements
- that may have different shapes into a `tf.SparseTensor`.
+ See also `tf.data.experimental.dense_to_sparse_batch`, which combines
+ elements that may have different shapes into a `tf.SparseTensor`.
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index 3bbebd7878..aca989e03a 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -31,7 +31,7 @@ class Optional(object):
An `Optional` can represent the result of an operation that may fail as a
value, rather than raising an exception and halting execution. For example,
- `tf.contrib.data.get_next_as_optional` returns an `Optional` that either
+ `tf.data.experimental.get_next_as_optional` returns an `Optional` that either
contains the next value from a `tf.data.Iterator` if one exists, or a "none"
value that indicates the end of the sequence has been reached.
"""
@@ -111,7 +111,7 @@ class Optional(object):
class _OptionalImpl(Optional):
- """Concrete implementation of `tf.contrib.data.Optional`.
+ """Concrete implementation of `tf.data.experimental.Optional`.
NOTE(mrry): This implementation is kept private, to avoid defining
`Optional.__init__()` in the public API.
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index b0f26631f9..d08da6704c 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -129,7 +129,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy, buffer_output_elements, prefetch_input_elements):
- """See `tf.contrib.data.parallel_interleave()` for details."""
+ """See `tf.data.experimental.parallel_interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func,
cycle_length, block_length)
self._sloppy = ops.convert_to_tensor(
@@ -158,7 +158,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
# pylint: enable=protected-access
def _transformation_name(self):
- return "tf.contrib.data.parallel_interleave()"
+ return "tf.data.experimental.parallel_interleave()"
@tf_export("data.TFRecordDataset")