author     Derek Murray <mrry@google.com>                    2018-10-01 16:45:11 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-01 16:50:05 -0700
commit     b72265dc002e712fc3d0f33434f13c7a36a484b2 (patch)
tree       f92d1f23c329654772f95d93f5cf4458741b72df
parent     bb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (diff)
[tf.data] Deprecate `tf.contrib.data` and introduce `tf.data.experimental` to replace it.
This change prepares `tf.data` for TensorFlow 2.0, where `tf.contrib` will no longer exist. It retains the pre-existing endpoints in `tf.contrib.data` with deprecation warnings.

Note there are some exceptions to the move:

* Deprecated symbols in `tf.contrib.data` have not been moved to `tf.data.experimental`, because replacements already exist.

* `tf.contrib.data.LMDBDataset` has not been moved, because we plan to move it to a SIG-maintained repository.

* `tf.contrib.data.assert_element_shape()` has not yet been moved, because it depends on functionality in `tf.contrib`, and it will move in a later change.

* `tf.contrib.data.AUTOTUNE` has not yet been moved, because we have not yet determined how to `tf_export()` a Python integer.

* The stats-related API endpoints have not yet appeared in a released version of TensorFlow, so these are moved to `tf.data.experimental` without retaining an endpoint in `tf.contrib.data`.

In addition, this change includes some build rule and ApiDef refactoring:

* Some of the "//third_party/tensorflow/python:training" dependencies had to be split in order to avoid a circular dependency.

* The `tf.contrib.stateless` ops now have a private core library for the generated wrappers (and accordingly are hidden in their ApiDef) so that `tf.data.experimental.sample_from_datasets()` can depend on them.

PiperOrigin-RevId: 215304249
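For readers porting code, here is a minimal before/after sketch of the endpoint move described above. It is an illustration rather than part of this change, assumes a TensorFlow build that includes this commit, and uses `shuffle_and_repeat` purely as a representative transformation:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# Before: the transformation lives under tf.contrib.data. After this change it
# still works, but emits a deprecation warning pointing at the new endpoint.
shuffled_old = dataset.apply(
    tf.contrib.data.shuffle_and_repeat(buffer_size=10, count=2))

# After: the same transformation under its new, supported endpoint.
shuffled_new = dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=10, count=2))

# Note: per the exceptions listed above, tf.contrib.data.AUTOTUNE and
# tf.contrib.data.assert_element_shape() are *not* yet available under
# tf.data.experimental at this revision.
```

The hunks below reflect this move: implementations migrate under tensorflow/python/data/experimental/, while the `tf.contrib.data` endpoints are retained as deprecated wrappers.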
-rw-r--r--tensorflow/contrib/bigtable/README.md4
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py4
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt1
-rw-r--r--tensorflow/contrib/data/README.md18
-rw-r--r--tensorflow/contrib/data/__init__.py11
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD560
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py226
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py62
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py527
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD170
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py549
-rw-r--r--tensorflow/contrib/data/python/ops/counter.py13
-rw-r--r--tensorflow/contrib/data/python/ops/enumerate_ops.py15
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py37
-rw-r--r--tensorflow/contrib/data/python/ops/get_single_element.py29
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py441
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py149
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops.py167
-rw-r--r--tensorflow/contrib/data/python/ops/parsing_ops.py107
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py486
-rw-r--r--tensorflow/contrib/data/python/ops/random_ops.py34
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py674
-rw-r--r--tensorflow/contrib/data/python/ops/resampling.py260
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py137
-rw-r--r--tensorflow/contrib/data/python/ops/shuffle_ops.py56
-rw-r--r--tensorflow/contrib/data/python/ops/threadpool.py88
-rw-r--r--tensorflow/contrib/data/python/ops/unique.py43
-rw-r--r--tensorflow/contrib/data/python/ops/writers.py40
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py2
-rw-r--r--tensorflow/contrib/eager/python/datasets.py4
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py12
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn_test.py2
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py2
-rw-r--r--tensorflow/contrib/stateless/BUILD8
-rw-r--r--tensorflow/contrib/stateless/__init__.py5
-rw-r--r--tensorflow/contrib/tpu/python/tpu/datasets.py4
-rw-r--r--tensorflow/contrib/tpu/tpu_estimator.md2
-rw-r--r--tensorflow/contrib/training/BUILD2
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py2
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt4
-rw-r--r--tensorflow/examples/get_started/regression/test.py2
-rw-r--r--tensorflow/python/BUILD34
-rw-r--r--tensorflow/python/data/BUILD1
-rw-r--r--tensorflow/python/data/__init__.py1
-rw-r--r--tensorflow/python/data/experimental/BUILD16
-rw-r--r--tensorflow/python/data/experimental/__init__.py109
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/BUILD569
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py)317
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/bucketing_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/bucketing_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py)30
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py)6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/BUILD (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/BUILD)30
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py)6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py)3
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/resample_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/resample_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/BUILD (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/BUILD)46
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py692
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py)6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py85
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py)0
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py)2
-rw-r--r--tensorflow/python/data/experimental/ops/BUILD377
-rw-r--r--tensorflow/python/data/experimental/ops/batching.py669
-rw-r--r--tensorflow/python/data/experimental/ops/counter.py55
-rw-r--r--tensorflow/python/data/experimental/ops/enumerate_ops.py60
-rw-r--r--tensorflow/python/data/experimental/ops/error_ops.py78
-rw-r--r--tensorflow/python/data/experimental/ops/get_single_element.py72
-rw-r--r--tensorflow/python/data/experimental/ops/grouping.py551
-rw-r--r--tensorflow/python/data/experimental/ops/indexed_dataset_ops.py (renamed from tensorflow/contrib/data/python/ops/indexed_dataset_ops.py)0
-rw-r--r--tensorflow/python/data/experimental/ops/interleave_ops.py262
-rw-r--r--tensorflow/python/data/experimental/ops/iterator_ops.py268
-rw-r--r--tensorflow/python/data/experimental/ops/map_defun.py (renamed from tensorflow/contrib/data/python/ops/map_defun.py)0
-rw-r--r--tensorflow/python/data/experimental/ops/optimization.py (renamed from tensorflow/contrib/data/python/ops/optimization.py)0
-rw-r--r--tensorflow/python/data/experimental/ops/parsing_ops.py152
-rw-r--r--tensorflow/python/data/experimental/ops/prefetching_ops.py531
-rw-r--r--tensorflow/python/data/experimental/ops/random_ops.py54
-rw-r--r--tensorflow/python/data/experimental/ops/readers.py904
-rw-r--r--tensorflow/python/data/experimental/ops/resampling.py296
-rw-r--r--tensorflow/python/data/experimental/ops/scan_ops.py177
-rw-r--r--tensorflow/python/data/experimental/ops/shuffle_ops.py102
-rw-r--r--tensorflow/python/data/experimental/ops/stats_ops.py (renamed from tensorflow/contrib/data/python/ops/stats_ops.py)14
-rw-r--r--tensorflow/python/data/experimental/ops/threadpool.py104
-rw-r--r--tensorflow/python/data/experimental/ops/unique.py79
-rw-r--r--tensorflow/python/data/experimental/ops/writers.py60
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py4
-rw-r--r--tensorflow/python/data/ops/optional_ops.py4
-rw-r--r--tensorflow/python/data/ops/readers.py4
-rw-r--r--tensorflow/python/debug/examples/debug_tflearn_iris.py14
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files.bzl1
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files_v1.bzl1
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt30
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt127
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optional.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt127
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-reducer.pbtxt21
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt127
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-aggregator.pbtxt13
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-t-f-record-writer.pbtxt13
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt139
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt30
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt127
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optional.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt127
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-reducer.pbtxt21
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt127
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-aggregator.pbtxt13
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-t-f-record-writer.pbtxt13
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt139
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt4
-rw-r--r--tensorflow/tools/pip_package/BUILD4
182 files changed, 8389 insertions, 4960 deletions
diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md
index f33eaf7e3d..2c44abed5e 100644
--- a/tensorflow/contrib/bigtable/README.md
+++ b/tensorflow/contrib/bigtable/README.md
@@ -203,7 +203,7 @@ def interleave_fn(index):
start = tf.string_join(['training_data_', start_idx_str])
end = tf.string_join(['training_data_', end_idx_str])
return table.scan_range(start_idx, end_idx, columns=columns)
-ds = ds.apply(tf.contrib.data.parallel_interleave(
+ds = ds.apply(tf.data.experimental.parallel_interleave(
interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1))
```
@@ -249,7 +249,7 @@ def make_row_key_dataset():
- ...
- fake-data-23498103
"""
- counter_dataset = tf.contrib.data.Counter()
+ counter_dataset = tf.data.experimental.Counter()
width = 8
row_key_prefix = 'fake-data-'
ds = counter_dataset.map(lambda index: tf.as_string(index,
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index cf56822ff4..7c87b0daeb 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -31,8 +31,8 @@ from six import iteritems
from six import string_types
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
-from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.util import loader
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -228,7 +228,7 @@ class BigtableTable(object):
"""Retrieves a sampling of row keys from the Bigtable table.
This dataset is most often used in conjunction with
- `tf.contrib.data.parallel_interleave` to construct a set of ranges for
+ `tf.data.experimental.parallel_interleave` to construct a set of ranges for
scanning in parallel.
Returns:
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 9b80eb559f..6e72670142 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -134,7 +134,6 @@ tensorflow/contrib/cudnn_rnn/python/ops
tensorflow/contrib/data
tensorflow/contrib/data/python
tensorflow/contrib/data/python/kernel_tests
-tensorflow/contrib/data/python/kernel_tests/serialization
tensorflow/contrib/data/python/ops
tensorflow/contrib/decision_trees
tensorflow/contrib/decision_trees/proto
diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md
index 848782e8d8..90be7a66ca 100644
--- a/tensorflow/contrib/data/README.md
+++ b/tensorflow/contrib/data/README.md
@@ -1,10 +1,12 @@
`tf.contrib.data` API
=====================
-NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead.
-We are continuing to support existing code using the `tf.contrib.data` APIs in
-the current version of TensorFlow, but will eventually remove support. The
-`tf.data` APIs are subject to backwards compatibility guarantees.
+NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead,
+or `tf.data.experimental` for the experimental transformations previously hosted
+in this module. We are continuing to support existing code using the
+`tf.contrib.data` APIs in the current version of TensorFlow, but will eventually
+remove support. The non-experimental `tf.data` APIs are subject to backwards
+compatibility guarantees.
Porting your code to `tf.data`
------------------------------
@@ -25,13 +27,13 @@ instead apply them using `Dataset.apply()` transformation. The full list of
changes is as follows:
* `dataset.dense_to_sparse_batch(...)` is now
- `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`.
+ `dataset.apply(tf.data.experimental.dense_to_sparse_batch(...)`.
* `dataset.enumerate(...)` is now
- `dataset.apply(tf.contrib.data.enumerate_dataset(...))`.
+ `dataset.apply(tf.data.experimental.enumerate_dataset(...))`.
* `dataset.group_by_window(...)` is now
- `dataset.apply(tf.contrib.data.group_by_window(...))`.
+ `dataset.apply(tf.data.experimental.group_by_window(...))`.
* `dataset.ignore_errors()` is now
- `dataset.apply(tf.contrib.data.ignore_errors())`.
+ `dataset.apply(tf.data.experimental.ignore_errors())`.
* `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`.
The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 3cb51279c3..c3d3e981fa 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -96,10 +96,6 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
-
-# Optimization constant that can be used to enable auto-tuning.
-from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
-
from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -114,11 +110,12 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
-from tensorflow.contrib.data.python.ops.stats_ops import latency_stats
-from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator
-from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
from tensorflow.python.data.ops.optional_ops import Optional
# pylint: enable=unused-import
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 33784afa3f..42f538b4ba 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -8,51 +8,17 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
- name = "batch_dataset_op_test",
- size = "medium",
- srcs = ["batch_dataset_op_test.py"],
+ name = "assert_element_shape_test",
+ srcs = ["assert_element_shape_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss", # (b/79552534)
- "no_pip",
- ],
deps = [
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "bucketing_test",
- size = "medium",
- srcs = ["bucketing_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/kernel_tests:test_base",
@@ -62,147 +28,6 @@ py_test(
)
py_test(
- name = "csv_dataset_op_test",
- size = "medium",
- srcs = ["csv_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:session",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/eager:context",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "dataset_constructor_op_test",
- size = "medium",
- srcs = ["dataset_constructor_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "manual",
- "nomac", # b/62040583
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_test(
- name = "directed_interleave_dataset_test",
- size = "medium",
- srcs = ["directed_interleave_dataset_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:random_seed",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "get_single_element_test",
- size = "small",
- srcs = ["get_single_element_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:get_single_element",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "indexed_dataset_ops_test",
- srcs = ["indexed_dataset_ops_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "interleave_dataset_op_test",
- size = "medium",
- srcs = ["interleave_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "notap",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator:estimator_py",
- ],
-)
-
-py_test(
name = "lmdb_dataset_op_test",
size = "medium",
srcs = ["lmdb_dataset_op_test.py"],
@@ -229,252 +54,18 @@ py_test(
)
py_test(
- name = "map_dataset_op_test",
- size = "medium",
- srcs = ["map_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "noasan", # times out
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "filter_dataset_op_test",
- size = "medium",
- srcs = ["filter_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "map_defun_op_test",
+ name = "reduce_dataset_test",
size = "small",
- srcs = ["map_defun_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ srcs = ["reduce_dataset_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:map_defun",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:session",
- "//tensorflow/python/data/kernel_tests:test_base",
- ],
-)
-
-py_test(
- name = "parsing_ops_test",
- size = "small",
- srcs = ["parsing_ops_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:parsing_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-cuda_py_test(
- name = "prefetching_ops_test",
- size = "small",
- srcs = ["prefetching_ops_test.py"],
- additional_deps = [
- "//tensorflow/contrib/data/python/ops:prefetching_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
- tags = ["no_windows_gpu"],
-)
-
-py_test(
- name = "range_dataset_op_test",
- size = "small",
- srcs = ["range_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:counter",
- "//tensorflow/contrib/data/python/ops:enumerate_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_library(
- name = "reader_dataset_ops_test_base",
- testonly = 1,
- srcs = [
- "reader_dataset_ops_test_base.py",
- ],
- srcs_version = "PY2AND3",
- visibility = [
- "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/core:protos_all_py",
+ "//tensorflow/contrib/data/python/ops:get_single_element",
+ "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "reader_dataset_ops_test",
- size = "medium",
- srcs = ["reader_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":reader_dataset_ops_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "resample_test",
- size = "medium",
- srcs = ["resample_test.py"],
- shard_count = 2,
- srcs_version = "PY2AND3",
- tags = [
- "noasan",
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:resampling",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:util",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "scan_dataset_op_test",
- size = "small",
- srcs = ["scan_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:scan_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/eager:context",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "shuffle_dataset_op_test",
- size = "medium",
- srcs = ["shuffle_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
@@ -496,142 +87,3 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
-
-py_library(
- name = "sql_dataset_op_test_base",
- srcs = ["sql_dataset_op_test_base.py"],
- srcs_version = "PY2AND3",
- visibility = [
- "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/kernel_tests:test_base",
- "@org_sqlite//:python",
- ],
-)
-
-py_test(
- name = "sql_dataset_op_test",
- size = "small",
- srcs = ["sql_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":sql_dataset_op_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- ],
-)
-
-py_test(
- name = "stats_dataset_ops_test",
- size = "medium",
- srcs = ["stats_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":reader_dataset_ops_test_base",
- ":stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "stats_dataset_test_base",
- srcs = ["stats_dataset_test_base.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/kernel_tests:test_base",
- ],
-)
-
-py_test(
- name = "threadpool_dataset_ops_test",
- size = "small",
- srcs = ["threadpool_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:threadpool",
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:script_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "unique_dataset_op_test",
- size = "small",
- srcs = ["unique_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "window_dataset_op_test",
- size = "medium",
- srcs = ["window_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "writer_ops_test",
- size = "small",
- srcs = ["writer_ops_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:writers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:readers",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
new file mode 100644
index 0000000000..0456463a19
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
@@ -0,0 +1,226 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class AssertElementShapeTest(test_base.DatasetTestBase):
+
+ def test_assert_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(expected_shapes, dataset.output_shapes)
+
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def test_assert_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ partial_expected_shape = (
+ tensor_shape.TensorShape(None), # Unknown shape
+ tensor_shape.TensorShape((None, 4))) # Partial shape
+ result = dataset.apply(
+ batching.assert_element_shape(partial_expected_shape))
+ # Partial shapes are merged with actual shapes:
+ actual_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(actual_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
new file mode 100644
index 0000000000..e7281d5318
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
@@ -0,0 +1,62 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import get_single_element
+from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("SumZero", 0),
+ ("SumOne", 1),
+ ("SumFive", 5),
+ ("SumTen", 10),
+ )
+ def testReduceDataset(self, stop):
+ def init_fn(_):
+ return np.int64(0)
+
+ def reduce_fn(state, value):
+ return state + value
+
+ def finalize_fn(state):
+ return state
+
+ sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+
+ stop_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset_ops.Dataset.range(stop_t)
+ element = get_single_element.reduce_dataset(dataset, sum_reducer)
+
+ with self.cached_session() as sess:
+ value = sess.run(element, feed_dict={stop_t: stop})
+ self.assertEqual(stop * (stop - 1) / 2, value)
+
+
+if __name__ == "__main__":
+ test.main()
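Note: the new test sums `range(stop)` through a `Reducer`; after this change `reduce_dataset()` simply delegates to the core `tf.data.Dataset.reduce()`. The following equivalent sketch uses that core API directly (illustrative only, graph mode assumed).

    import numpy as np
    import tensorflow as tf

    dataset = tf.data.Dataset.range(10)
    # State starts at 0 and accumulates each element; this mirrors the
    # sum_reducer in the test above, with an identity finalize step.
    total = dataset.reduce(np.int64(0), lambda state, value: state + value)

    with tf.Session() as sess:
      print(sess.run(total))  # 45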
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
deleted file mode 100644
index 79134c7bc6..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ /dev/null
@@ -1,527 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def _structuredDataset(self, structure, shape, dtype):
- if structure is None:
- return dataset_ops.Dataset.from_tensors(
- array_ops.zeros(shape, dtype=dtype))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredDataset(substructure, shape, dtype)
- for substructure in structure
- ]))
-
- def _structuredElement(self, structure, shape, dtype):
- if structure is None:
- return array_ops.zeros(shape, dtype=dtype)
- else:
- return tuple([
- self._structuredElement(substructure, shape, dtype)
- for substructure in structure
- ])
-
- def _assertEqual(self, xs, ys):
- self.assertEqual(type(xs), type(ys))
- if isinstance(xs, tuple) and isinstance(ys, tuple):
- self.assertEqual(len(xs), len(ys))
- for x, y in zip(xs, ys):
- self._assertEqual(x, y)
- elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray):
- self.assertAllEqual(xs, ys)
- else:
- self.assertEqual(xs, ys)
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetFlatMap(self, structure, shape, dtype):
- """Tests windowing by chaining it with flat map.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return args[0]
- return dataset_ops.Dataset.zip(
- tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
-
- dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).flat_map(fn)
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(self._structuredElement(structure, shape, dtype))
- for _ in range(5):
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetBatchDense(self, structure, shape, dtype):
- """Tests batching of dense tensor windows.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.batch_window(args[0])
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
- for arg in args
- ])
-
- dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredElement(structure, np.concatenate(
- ([5], shape), axis=0), dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([])),
- ("2", np.int32([1])),
- ("3", np.int32([1, 2, 3])),
- )
- def testWindowDatasetBatchDenseDynamicShape(self, shape):
- """Tests batching of dynamically shaped dense tensor windows.
-
- Args:
- shape: the input shape
- """
-
- shape_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.zeros(shape_t)).repeat(5).apply(
- grouping.window_dataset(5)).apply(
- grouping._map_x_dataset(batching.batch_window))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shape_t: shape})
- expected = sess.run(
- self._structuredElement(None, np.concatenate(([5], shape), axis=0),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- def _make_dense_to_sparse_fn(self, is_scalar):
-
- def dense_to_sparse_scalar(tensor):
- indices = [[]]
- values = array_ops.expand_dims(tensor, 0)
- shape = []
- return sparse_tensor.SparseTensorValue(indices, values, shape)
-
- def dense_to_sparse_non_scalar(tensor):
- indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool))
- values = array_ops.gather_nd(tensor, indices)
- shape = array_ops.shape(tensor, out_type=dtypes.int64)
- return sparse_tensor.SparseTensorValue(indices, values, shape)
-
- if is_scalar:
- return dense_to_sparse_scalar
- return dense_to_sparse_non_scalar
-
- def _structuredSparseDataset(self, structure, shape, dtype):
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- if structure is None:
- return dataset_ops.Dataset.from_tensors(
- dense_to_sparse(array_ops.zeros(shape, dtype=dtype)))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredSparseDataset(substructure, shape, dtype)
- for substructure in structure
- ]))
-
- def _structuredSparseElement(self, structure, shape, dtype):
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- if structure is None:
- return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
- else:
- return tuple([
- self._structuredSparseElement(substructure, shape, dtype)
- for substructure in structure
- ])
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetBatchSparse(self, structure, shape, dtype):
- """Tests batching of sparse tensor windows.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.batch_window(args[0])
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
- for arg in args
- ])
-
- dataset = self._structuredSparseDataset(
- structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredSparseElement(structure,
- np.concatenate(([5], shape), axis=0),
- dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([])),
- ("2", np.int32([1])),
- ("3", np.int32([1, 2, 3])),
- )
- def testWindowDatasetBatchSparseDynamicShape(self, shape):
- """Tests batching of dynamically shaped sparse tensor windows.
-
- Args:
- shape: the input shape
- """
-
- shape_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map(
- self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test
- grouping.window_dataset(5)).apply(
- grouping._map_x_dataset(batching.batch_window))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shape_t: shape})
- expected = sess.run(
- self._structuredSparseElement(None,
- np.concatenate(([5], shape), axis=0),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- def _structuredRaggedDataset(self, structure, shapes, dtype):
-
- if structure is None:
- return dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtype))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredRaggedDataset(substructure, shapes, dtype)
- for substructure in structure
- ]))
-
- @parameterized.named_parameters(
- ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("8", (None,
- (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
- )
- def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
- padded_shape):
- """Tests padded batching of dense tensor windows.
-
- Args:
- structure: the input structure
- shapes: the input shapes
- dtype: the input data type
- padded_shape: the shape to pad the output to
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.padded_batch_window(args[0], padded_shape)
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
- arg, padded_shape) for arg in args
- ])
-
- dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- expected = sess.run(
- self._structuredElement(
- structure,
- np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([[1], [2], [3]]), [-1]),
- ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
- )
- def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
- """Tests padded batching of dynamically shaped dense tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- shapes_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shapes_t: shapes})
- expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- expected = sess.run(
- self._structuredElement(
- None, np.concatenate((np.int32([len(shapes)]), expected_shape)),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([[1]]), np.int32([0])),
- ("2", np.int32([[10], [20]]), np.int32([15])),
- )
- def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
- """Tests invalid padded batching of dense tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def _structuredRaggedSparseDataset(self, structure, shapes, dtype):
-
- def map_fn(shape):
- dense_to_sparse = self._make_dense_to_sparse_fn(False)
- return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
-
- if structure is None:
- return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn)
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredRaggedSparseDataset(substructure, shapes, dtype)
- for substructure in structure
- ]))
-
- def _structuredRaggedSparseElement(self, structure, shapes, dtype,
- padded_shape):
- if structure is None:
- dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- values = []
- for shape in shapes:
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
- padded_sparse = sparse_tensor.SparseTensor(sparse.indices,
- sparse.values, dense_shape)
- reshaped_sparse = sparse_ops.sparse_reshape(
- padded_sparse,
- array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0))
- values.append(reshaped_sparse)
- return sparse_ops.sparse_concat(0, values)
- else:
- return tuple([
- self._structuredRaggedSparseElement(substructure, shapes, dtype,
- padded_shape)
- for substructure in structure
- ])
-
- @parameterized.named_parameters(
- ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("8", (None,
- (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
- )
- def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
- padded_shape):
- """Tests padded batching of sparse tensor windows.
-
- Args:
- structure: the input structure
- shapes: the input shapes
- dtype: the input data type
- padded_shape: the shape to pad the output to
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.padded_batch_window(args[0], padded_shape)
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
- arg, padded_shape) for arg in args
- ])
-
- dataset = self._structuredRaggedSparseDataset(
- structure, shapes, dtype).apply(grouping.window_dataset(
- len(shapes))).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredRaggedSparseElement(structure, shapes, dtype,
- padded_shape))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int64([[1], [2], [3]]), [-1]),
- ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
- )
- def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
- padded_shape):
- """Tests padded batching of dynamically shaped sparse tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- shapes_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
- self._make_dense_to_sparse_fn(False)
- ).apply(grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shapes_t: shapes})
- expected = sess.run(
- self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
- padded_shape))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int64([[1]]), [0]),
- ("2", np.int64([[10], [20]]), [15]),
- )
- def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
- """Tests invalid padded batching of sparse tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
- self._make_dense_to_sparse_fn(False)
- ).apply(grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 5cd1ed542b..34dc2379d0 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -16,10 +16,7 @@ py_library(
srcs = ["counter.py"],
srcs_version = "PY2AND3",
deps = [
- ":scan_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:counter",
],
)
@@ -28,12 +25,7 @@ py_library(
srcs = ["get_single_element.py"],
srcs_version = "PY2AND3",
deps = [
- ":grouping",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- "//third_party/py/numpy",
+ "//tensorflow/python/data/experimental/ops:get_single_element",
],
)
@@ -44,10 +36,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
],
)
@@ -58,15 +47,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:random_ops",
],
)
@@ -79,7 +60,6 @@ py_library(
deps = [
":batching",
":interleave_ops",
- ":optimization",
":parsing_ops",
":shuffle_ops",
"//tensorflow/python:constant_op",
@@ -91,6 +71,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/data/util:convert",
@@ -106,7 +87,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
],
)
@@ -125,6 +106,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
@@ -138,8 +120,7 @@ py_library(
srcs = ["enumerate_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
],
)
@@ -148,10 +129,7 @@ py_library(
srcs = ["error_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:error_ops",
],
)
@@ -160,16 +138,7 @@ py_library(
srcs = ["grouping.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:grouping",
],
)
@@ -178,30 +147,7 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":random_ops",
- "//tensorflow/contrib/stateless",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-py_library(
- name = "optimization",
- srcs = ["optimization.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
],
)
@@ -210,25 +156,7 @@ py_library(
srcs = ["parsing_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_library(
- name = "map_defun",
- srcs = ["map_defun.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:parsing_ops",
],
)
@@ -237,18 +165,7 @@ py_library(
srcs = ["resampling.py"],
srcs_version = "PY2AND3",
deps = [
- ":batching",
- ":interleave_ops",
- ":scan_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:logging_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
+ "//tensorflow/python/data/experimental/ops:resampling",
],
)
@@ -257,12 +174,7 @@ py_library(
srcs = ["scan_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
],
)
@@ -282,31 +194,11 @@ py_library(
)
py_library(
- name = "stats_ops",
- srcs = ["stats_ops.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-py_library(
name = "threadpool",
srcs = ["threadpool.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
+ "//tensorflow/python/data/experimental/ops:threadpool",
],
)
@@ -317,11 +209,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:unique",
],
)
@@ -332,20 +220,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_library(
- name = "indexed_dataset_ops",
- srcs = ["indexed_dataset_ops.py"],
- deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:writers",
],
)
@@ -353,11 +228,7 @@ py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
],
)
@@ -370,17 +241,14 @@ py_library(
":error_ops",
":get_single_element",
":grouping",
- ":indexed_dataset_ops",
":interleave_ops",
- ":map_defun",
- ":optimization",
":prefetching_ops",
+ ":random_ops",
":readers",
":resampling",
":scan_ops",
":shuffle_ops",
":sliding",
- ":stats_ops",
":threadpool",
":unique",
":writers",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 7a0f221284..8c60459ca8 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,134 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import get_single_element
-from tensorflow.contrib.data.python.ops import grouping
from tensorflow.contrib.framework import with_shape
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation
-def batch_window(dataset):
- """Batches a window of tensors.
-
- Args:
- dataset: the input dataset.
-
- Returns:
- A `Tensor` representing the batch of the entire input dataset.
- """
- if isinstance(dataset.output_classes, tuple):
- raise TypeError("Input dataset expected to have a single component")
- if dataset.output_classes is ops.Tensor:
- return _batch_dense_window(dataset)
- elif dataset.output_classes is sparse_tensor.SparseTensor:
- return _batch_sparse_window(dataset)
- else:
- raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
-
-
-def _batch_dense_window(dataset):
- """Batches a window of dense tensors."""
-
- def key_fn(_):
- return np.int64(0)
-
- def shape_init_fn(_):
- return array_ops.shape(first_element)
-
- def shape_reduce_fn(state, value):
- check_ops.assert_equal(state, array_ops.shape(value))
- return state
-
- def finalize_fn(state):
- return state
-
- if dataset.output_shapes.is_fully_defined():
- shape = dataset.output_shapes
- else:
- first_element = get_single_element.get_single_element(dataset.take(1))
- shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
- finalize_fn)
- shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
-
- def batch_init_fn(_):
- batch_shape = array_ops.concat([[0], shape], 0)
- return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
-
- def batch_reduce_fn(state, value):
- return array_ops.concat([state, [value]], 0)
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
-def _batch_sparse_window(dataset):
- """Batches a window of sparse tensors."""
-
- def key_fn(_):
- return np.int64(0)
-
- def shape_init_fn(_):
- return first_element.dense_shape
-
- def shape_reduce_fn(state, value):
- check_ops.assert_equal(state, value.dense_shape)
- return state
-
- def finalize_fn(state):
- return state
-
- if dataset.output_shapes.is_fully_defined():
- shape = dataset.output_shapes
- else:
- first_element = get_single_element.get_single_element(dataset.take(1))
- shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
- finalize_fn)
- shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
-
- def batch_init_fn(_):
- indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
- return sparse_tensor.SparseTensor(
- indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
- values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
- dense_shape=array_ops.concat(
- [np.array([0], dtype=np.int64),
- math_ops.cast(shape, dtypes.int64)], 0))
-
- def batch_reduce_fn(state, value):
- return sparse_ops.sparse_concat(0, [state, value])
-
- def reshape_fn(value):
- return sparse_ops.sparse_reshape(
- value,
- array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.map(reshape_fn).apply(
- grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.dense_to_sparse_batch(...)`.")
def dense_to_sparse_batch(batch_size, row_shape):
"""A transformation that batches ragged elements into `tf.SparseTensor`s.
@@ -187,201 +67,10 @@ def dense_to_sparse_batch(batch_size, row_shape):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
-
- return _apply_fn
-
-
-def padded_batch_window(dataset, padded_shape, padding_value=None):
- """Batches a window of tensors with padding.
-
- Args:
- dataset: the input dataset.
- padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like
- object representing the shape to which the input elements should be padded
- prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
- `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
- maximum size of that dimension in each batch.
- padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
- padding value to use. Defaults are `0` for numeric types and the empty
- string for string types. If `dataset` contains `tf.SparseTensor`, this
- value is ignored.
-
- Returns:
- A `Tensor` representing the batch of the entire input dataset.
-
- Raises:
- ValueError: if invalid arguments are provided.
- """
- if not issubclass(dataset.output_classes,
- (ops.Tensor, sparse_tensor.SparseTensor)):
- raise TypeError("Input dataset expected to have a single tensor component")
- if issubclass(dataset.output_classes, (ops.Tensor)):
- return _padded_batch_dense_window(dataset, padded_shape, padding_value)
- elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)):
- if padding_value is not None:
- raise ValueError("Padding value not allowed for sparse tensors")
- return _padded_batch_sparse_window(dataset, padded_shape)
- else:
- raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
-
-
-def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
- """Batches a window of dense tensors with padding."""
-
- padded_shape = math_ops.cast(
- convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
-
- def key_fn(_):
- return np.int64(0)
-
- def max_init_fn(_):
- return padded_shape
-
- def max_reduce_fn(state, value):
- """Computes the maximum shape to pad to."""
- condition = math_ops.reduce_all(
- math_ops.logical_or(
- math_ops.less_equal(array_ops.shape(value), padded_shape),
- math_ops.equal(padded_shape, -1)))
- assert_op = control_flow_ops.Assert(condition, [
- "Actual shape greater than padded shape: ",
- array_ops.shape(value), padded_shape
- ])
- with ops.control_dependencies([assert_op]):
- return math_ops.maximum(state, array_ops.shape(value))
-
- def finalize_fn(state):
- return state
-
- # Compute the padded shape.
- max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
- padded_shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
-
- if padding_value is None:
- if dataset.output_types == dtypes.string:
- padding_value = ""
- elif dataset.output_types == dtypes.bool:
- padding_value = False
- elif dataset.output_types == dtypes.variant:
- raise TypeError("Unable to create padding for field of type 'variant'")
- else:
- padding_value = 0
-
- def batch_init_fn(_):
- batch_shape = array_ops.concat(
- [np.array([0], dtype=np.int32), padded_shape], 0)
- return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
-
- def batch_reduce_fn(state, value):
- return array_ops.concat([state, [value]], 0)
-
- def pad_fn(value):
- shape = array_ops.shape(value)
- left = array_ops.zeros_like(shape)
- right = padded_shape - shape
- return array_ops.pad(
- value, array_ops.stack([left, right], 1), constant_values=padding_value)
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.map(pad_fn).apply(
- grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
-def _padded_batch_sparse_window(dataset, padded_shape):
- """Batches a window of sparse tensors with padding."""
-
- def key_fn(_):
- return np.int64(0)
-
- def max_init_fn(_):
- return convert.partial_shape_to_tensor(padded_shape)
-
- def max_reduce_fn(state, value):
- """Computes the maximum shape to pad to."""
- condition = math_ops.reduce_all(
- math_ops.logical_or(
- math_ops.less_equal(value.dense_shape, padded_shape),
- math_ops.equal(padded_shape, -1)))
- assert_op = control_flow_ops.Assert(condition, [
- "Actual shape greater than padded shape: ", value.dense_shape,
- padded_shape
- ])
- with ops.control_dependencies([assert_op]):
- return math_ops.maximum(state, value.dense_shape)
-
- def finalize_fn(state):
- return state
-
- # Compute the padded shape.
- max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
- padded_shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
-
- def batch_init_fn(_):
- indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
- 0)
- return sparse_tensor.SparseTensor(
- indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
- values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
- dense_shape=array_ops.concat(
- [np.array([0], dtype=np.int64), padded_shape], 0))
-
- def batch_reduce_fn(state, value):
- padded_value = sparse_tensor.SparseTensor(
- indices=value.indices, values=value.values, dense_shape=padded_shape)
- reshaped_value = sparse_ops.sparse_reshape(
- padded_value,
- array_ops.concat(
- [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
- return sparse_ops.sparse_concat(0, [state, reshaped_value])
-
- reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
-
-
-class _UnbatchDataset(dataset_ops.UnaryDataset):
- """A dataset that splits the elements of its input into multiple elements."""
-
- def __init__(self, input_dataset):
- """See `unbatch()` for more details."""
- super(_UnbatchDataset, self).__init__(input_dataset)
- flat_shapes = nest.flatten(input_dataset.output_shapes)
- if any(s.ndims == 0 for s in flat_shapes):
- raise ValueError("Cannot unbatch an input with scalar components.")
- known_batch_dim = tensor_shape.Dimension(None)
- for s in flat_shapes:
- try:
- known_batch_dim = known_batch_dim.merge_with(s[0])
- except ValueError:
- raise ValueError("Cannot unbatch an input whose components have "
- "different batch sizes.")
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.unbatch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return nest.map_structure(lambda s: s[1:],
- self._input_dataset.output_shapes)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return batching.dense_to_sparse_batch(batch_size, row_shape)
+@deprecation.deprecated(None, "Use `tf.data.experimental.unbatch()`.")
def unbatch():
"""Splits elements of a dataset into multiple elements on the batch dimension.
@@ -403,39 +92,7 @@ def unbatch():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- if not sparse.any_sparse(dataset.output_classes):
- return _UnbatchDataset(dataset)
-
- # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
- # are normalized to the rank-1 dense representation, so that the
- # sparse-oblivious unbatching logic will slice them
- # appropriately. This leads to a somewhat inefficient re-encoding step
- # for all SparseTensor components.
- # TODO(mrry): Consider optimizing this in future
- # if it turns out to be a bottleneck.
- def normalize(arg, *rest):
- if rest:
- return sparse.serialize_many_sparse_tensors((arg,) + rest)
- else:
- return sparse.serialize_many_sparse_tensors(arg)
-
- normalized_dataset = dataset.map(normalize)
-
- # NOTE(mrry): Our `map()` has lost information about the sparseness
- # of any SparseTensor components, so re-apply the structure of the
- # original dataset.
- restructured_dataset = _RestructuredDataset(
- normalized_dataset,
- dataset.output_types,
- dataset.output_shapes,
- dataset.output_classes,
- allow_unsafe_cast=True)
- return _UnbatchDataset(restructured_dataset)
-
- return _apply_fn
+ return batching.unbatch()
@deprecation.deprecated(
@@ -514,135 +171,8 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
-
- def __init__(self, input_dataset, batch_size, row_shape):
- """See `Dataset.dense_to_sparse_batch()` for more details."""
- super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
- if not isinstance(input_dataset.output_types, dtypes.DType):
- raise TypeError("DenseToSparseDataset requires an input whose elements "
- "have a single component, whereas the input has %r." %
- input_dataset.output_types)
- self._input_dataset = input_dataset
- self._batch_size = batch_size
- self._row_shape = row_shape
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.dense_to_sparse_batch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._batch_size,
- row_shape=convert.partial_shape_to_tensor(self._row_shape),
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return sparse_tensor.SparseTensor
-
- @property
- def output_shapes(self):
- return tensor_shape.vector(None).concatenate(self._row_shape)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _RestructuredDataset(dataset_ops.UnaryDataset):
- """An internal helper for changing the structure and shape of a dataset."""
-
- def __init__(self,
- dataset,
- output_types,
- output_shapes=None,
- output_classes=None,
- allow_unsafe_cast=False):
- """Creates a new dataset with the given output types and shapes.
-
- The given `dataset` must have a structure that is convertible:
- * `dataset.output_types` must be the same as `output_types` module nesting.
- * Each shape in `dataset.output_shapes` must be compatible with each shape
- in `output_shapes` (if given).
-
- Note: This helper permits "unsafe casts" for shapes, equivalent to using
- `tf.Tensor.set_shape()` where domain-specific knowledge is available.
-
- Args:
- dataset: A `Dataset` object.
- output_types: A nested structure of `tf.DType` objects.
- output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
- If omitted, the shapes will be inherited from `dataset`.
- output_classes: (Optional.) A nested structure of class types.
- If omitted, the class types will be inherited from `dataset`.
- allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
- reported output types and shapes of the restructured dataset, e.g. to
- switch a sparse tensor represented as `tf.variant` to its user-visible
- type and shape.
-
- Raises:
- ValueError: If either `output_types` or `output_shapes` is not compatible
- with the structure of `dataset`.
- """
- super(_RestructuredDataset, self).__init__(dataset)
- self._input_dataset = dataset
-
- if not allow_unsafe_cast:
- # Validate that the types are compatible.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- flat_original_types = nest.flatten(dataset.output_types)
- flat_new_types = nest.flatten(output_types)
- if flat_original_types != flat_new_types:
- raise ValueError(
- "Dataset with output types %r cannot be restructured to have "
- "output types %r" % (dataset.output_types, output_types))
-
- self._output_types = output_types
-
- if output_shapes is None:
- # Inherit shapes from the original `dataset`.
- self._output_shapes = nest.pack_sequence_as(output_types,
- nest.flatten(
- dataset.output_shapes))
- else:
- if not allow_unsafe_cast:
- # Validate that the shapes are compatible.
- nest.assert_same_structure(output_types, output_shapes)
- flat_original_shapes = nest.flatten(dataset.output_shapes)
- flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
- for original_shape, new_shape in zip(flat_original_shapes,
- flat_new_shapes):
- if not original_shape.is_compatible_with(new_shape):
- raise ValueError(
- "Dataset with output shapes %r cannot be restructured to have "
- "incompatible output shapes %r" % (dataset.output_shapes,
- output_shapes))
- self._output_shapes = nest.map_structure_up_to(
- output_types, tensor_shape.as_shape, output_shapes)
- if output_classes is None:
- # Inherit class types from the original `dataset`.
- self._output_classes = nest.pack_sequence_as(output_types,
- nest.flatten(
- dataset.output_classes))
- else:
- self._output_classes = output_classes
-
- def _as_variant_tensor(self):
- return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
-
+# TODO(b/116817045): Move this to `tf.data.experimental` when the `with_shape()`
+# function is available in the core.
def assert_element_shape(expected_shapes):
"""Assert the shape of this `Dataset`.
@@ -687,7 +217,8 @@ def assert_element_shape(expected_shapes):
def _apply_fn(dataset):
output_shapes = _merge_output_shapes(dataset.output_shapes,
expected_shapes)
- return _RestructuredDataset(
+ # pylint: disable=protected-access
+ return batching._RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,
output_shapes=output_shapes,
@@ -696,49 +227,7 @@ def assert_element_shape(expected_shapes):
return _apply_fn
-class _MapAndBatchDataset(dataset_ops.MapDataset):
- """A `Dataset` that maps a function over a batch of elements."""
-
- def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
- drop_remainder):
- """See `Dataset.map()` for details."""
- super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
- self._batch_size_t = ops.convert_to_tensor(
- batch_size, dtype=dtypes.int64, name="batch_size")
- self._num_parallel_calls_t = ops.convert_to_tensor(
- num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
- self._drop_remainder_t = ops.convert_to_tensor(
- drop_remainder, dtype=dtypes.bool, name="drop_remainder")
-
- self._batch_size = batch_size
- self._drop_remainder = drop_remainder
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.map_and_batch_dataset_v2(
- input_resource,
- self._map_func.captured_inputs,
- f=self._map_func,
- batch_size=self._batch_size_t,
- num_parallel_calls=self._num_parallel_calls_t,
- drop_remainder=self._drop_remainder_t,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_shapes(self):
- dim = self._batch_size if self._drop_remainder else None
- return nest.pack_sequence_as(self._output_shapes, [
- tensor_shape.vector(dim).concatenate(s)
- for s in nest.flatten(self._output_shapes)
- ])
-
- @property
- def output_types(self):
- return self._output_types
-
-
+@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch(...)`.")
def map_and_batch(map_func,
batch_size,
num_parallel_batches=None,
@@ -779,17 +268,5 @@ def map_and_batch(map_func,
ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
specified.
"""
-
- if num_parallel_batches is None and num_parallel_calls is None:
- num_parallel_calls = batch_size
- elif num_parallel_batches is not None and num_parallel_calls is None:
- num_parallel_calls = batch_size * num_parallel_batches
- elif num_parallel_batches is not None and num_parallel_calls is not None:
- raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
- "arguments are mutually exclusive.")
-
- def _apply_fn(dataset):
- return _MapAndBatchDataset(dataset, map_func, batch_size,
- num_parallel_calls, drop_remainder)
-
- return _apply_fn
+ return batching.map_and_batch(map_func, batch_size, num_parallel_batches,
+ drop_remainder, num_parallel_calls)
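Note: with the implementations moved into `tensorflow/python/data/experimental/ops/batching.py`, this contrib module is now a layer of deprecated wrappers. Existing call sites keep working but emit a warning. The sketch below is not part of the diff and assumes a build that includes this change; it shows the before/after for `unbatch()`.

    import tensorflow as tf

    batched = tf.data.Dataset.range(8).batch(4)

    # Pre-existing endpoint: still works, now emits a deprecation warning.
    unbatched_old = batched.apply(tf.contrib.data.unbatch())

    # Replacement endpoint introduced by this change.
    unbatched_new = batched.apply(tf.data.experimental.unbatch())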
diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py
index 6ef65f9624..4ff5bf3e39 100644
--- a/tensorflow/contrib/data/python/ops/counter.py
+++ b/tensorflow/contrib/data/python/ops/counter.py
@@ -17,13 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import scan_ops
-
-from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.Counter(...)`.")
def Counter(start=0, step=1, dtype=dtypes.int64):
"""Creates a `Dataset` that counts from `start` in steps of size `step`.
@@ -46,8 +45,4 @@ def Counter(start=0, step=1, dtype=dtypes.int64):
Returns:
A `Dataset` of scalar `dtype` elements.
"""
- with ops.name_scope("counter"):
- start = ops.convert_to_tensor(start, dtype=dtype, name="start")
- step = ops.convert_to_tensor(step, dtype=dtype, name="step")
- return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
- scan_ops.scan(start, lambda state, _: (state + step, state)))
+ return counter.Counter(start, step, dtype)
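Note: `Counter` is unchanged in behavior; it still yields an unbounded sequence starting at `start` with stride `step`. A short usage sketch against the new endpoint (illustrative, not part of the diff):

    import tensorflow as tf

    counter = tf.data.experimental.Counter(start=10, step=2)
    first_three = counter.take(3)  # yields 10, 12, 14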
diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py
index 490281e0d2..a21da4d3ec 100644
--- a/tensorflow/contrib/data/python/ops/enumerate_ops.py
+++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
+from tensorflow.python.data.experimental.ops import enumerate_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.enumerate_dataset(...)`.")
def enumerate_dataset(start=0):
"""A transformation that enumerate the elements of a dataset.
@@ -49,10 +50,4 @@ def enumerate_dataset(start=0):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
- return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
- dataset))
-
- return _apply_fn
+ return enumerate_ops.enumerate_dataset(start)
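Note: the wrapper keeps the documented semantics: each element is paired with a monotonically increasing `int64` index, like Python's built-in `enumerate`. Usage sketch (illustrative, not part of the diff):

    import tensorflow as tf

    letters = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
    indexed = letters.apply(tf.data.experimental.enumerate_dataset(start=100))
    # Yields (100, b"a"), (101, b"b"), (102, b"c").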
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index f962e623ee..0559a2e09c 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,10 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.ignore_errors()`.")
def ignore_errors():
"""Creates a `Dataset` from another `Dataset` and silently ignores any errors.
@@ -43,34 +44,4 @@ def ignore_errors():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _IgnoreErrorsDataset(dataset)
-
- return _apply_fn
-
-
-class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that silently ignores errors when computing its input."""
-
- def __init__(self, input_dataset):
- """See `Dataset.ignore_errors()` for details."""
- super(_IgnoreErrorsDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return error_ops.ignore_errors()
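Note: `ignore_errors()` likewise forwards to the experimental module; elements whose computation raises an error are dropped rather than terminating iteration. Sketch (illustrative, not part of the diff):

    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
    dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "divide"))
    dataset = dataset.apply(tf.data.experimental.ignore_errors())
    # Yields 1.0, 0.5, 0.25; the element that produces Inf is skipped.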
diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py
index a6713b017a..58ad9eea90 100644
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -19,13 +19,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.ops import get_single_element as experimental_get_single_element
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.get_single_element(...)`.")
def get_single_element(dataset):
"""Returns the single element in `dataset` as a nested structure of tensors.
@@ -61,18 +61,10 @@ def get_single_element(dataset):
InvalidArgumentError (at runtime): if `dataset` does not contain exactly
one element.
"""
- if not isinstance(dataset, dataset_ops.Dataset):
- raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
-
- nested_ret = nest.pack_sequence_as(
- dataset.output_types, gen_dataset_ops.dataset_to_single_element(
- dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(dataset)))
- return sparse.deserialize_sparse_tensors(
- nested_ret, dataset.output_types, dataset.output_shapes,
- dataset.output_classes)
+ return experimental_get_single_element.get_single_element(dataset)
+@deprecation.deprecated(None, "Use `tf.data.Dataset.reduce(...)`.")
def reduce_dataset(dataset, reducer):
"""Returns the result of reducing the `dataset` using `reducer`.
@@ -90,11 +82,4 @@ def reduce_dataset(dataset, reducer):
if not isinstance(dataset, dataset_ops.Dataset):
raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- # The sentinel dataset is used in case the reduced dataset is empty.
- sentinel_dataset = dataset_ops.Dataset.from_tensors(
- reducer.finalize_func(reducer.init_func(np.int64(0))))
- reduced_dataset = dataset.apply(
- grouping.group_by_reducer(lambda x: np.int64(0), reducer))
-
- return get_single_element(
- reduced_dataset.concatenate(sentinel_dataset).take(1))
+ return dataset.reduce(reducer.init_func(np.int64(0)), reducer.reduce_func)
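Minimal sketches of the two replacements above, assuming TF 1.x graph mode (`numpy` only supplies the int64 initial state, as in the shim itself):

import numpy as np
import tensorflow as tf

# get_single_element(): pull the lone element out as plain tensors.
doubled = tf.data.experimental.get_single_element(
    tf.data.Dataset.from_tensors(42).map(lambda x: x * 2))  # holds 84

# reduce_dataset() is now a thin wrapper over Dataset.reduce().
total = tf.data.Dataset.range(10).reduce(np.int64(0), lambda state, x: state + x)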
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 7cae33beb3..a99dc2f29a 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -17,20 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.group_by_reducer(...)`.")
def group_by_reducer(key_func, reducer):
"""A transformation that groups elements and performs a reduction.
@@ -52,14 +45,11 @@ def group_by_reducer(key_func, reducer):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _GroupByReducerDataset(dataset, key_func, reducer)
-
- return _apply_fn
+ return grouping.group_by_reducer(key_func, reducer)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.group_by_window(...)`.")
def group_by_window(key_func,
reduce_func,
window_size=None,
@@ -98,27 +88,12 @@ def group_by_window(key_func,
ValueError: if neither or both of {`window_size`, `window_size_func`} are
passed.
"""
- if (window_size is not None and window_size_func or
- not (window_size is not None or window_size_func)):
- raise ValueError("Must pass either window_size or window_size_func.")
-
- if window_size is not None:
-
- def constant_window_func(unused_key):
- return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
-
- window_size_func = constant_window_func
-
- assert window_size_func is not None
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _GroupByWindowDataset(dataset, key_func, reduce_func,
- window_size_func)
-
- return _apply_fn
+ return grouping.group_by_window(key_func, reduce_func, window_size,
+ window_size_func)
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.bucket_by_sequence_length(...)`.")
def bucket_by_sequence_length(element_length_func,
bucket_boundaries,
bucket_batch_sizes,
@@ -163,342 +138,12 @@ def bucket_by_sequence_length(element_length_func,
Raises:
ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
"""
- with ops.name_scope("bucket_by_seq_length"):
- if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
- raise ValueError(
- "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
-
- batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
-
- def element_to_bucket_id(*args):
- """Return int64 id of the length bucket for this element."""
- seq_length = element_length_func(*args)
-
- boundaries = list(bucket_boundaries)
- buckets_min = [np.iinfo(np.int32).min] + boundaries
- buckets_max = boundaries + [np.iinfo(np.int32).max]
- conditions_c = math_ops.logical_and(
- math_ops.less_equal(buckets_min, seq_length),
- math_ops.less(seq_length, buckets_max))
- bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
-
- return bucket_id
-
- def window_size_fn(bucket_id):
- # The window size is set to the batch size for this bucket
- window_size = batch_sizes[bucket_id]
- return window_size
-
- def make_padded_shapes(shapes, none_filler=None):
- padded = []
- for shape in nest.flatten(shapes):
- shape = tensor_shape.TensorShape(shape)
- shape = [
- none_filler if d.value is None else d
- for d in shape
- ]
- padded.append(shape)
- return nest.pack_sequence_as(shapes, padded)
-
- def batching_fn(bucket_id, grouped_dataset):
- """Batch elements in dataset."""
- batch_size = window_size_fn(bucket_id)
- if no_padding:
- return grouped_dataset.batch(batch_size)
- none_filler = None
- if pad_to_bucket_boundary:
- err_msg = ("When pad_to_bucket_boundary=True, elements must have "
- "length < max(bucket_boundaries).")
- check = check_ops.assert_less(
- bucket_id,
- constant_op.constant(len(bucket_batch_sizes) - 1,
- dtype=dtypes.int64),
- message=err_msg)
- with ops.control_dependencies([check]):
- boundaries = constant_op.constant(bucket_boundaries,
- dtype=dtypes.int64)
- bucket_boundary = boundaries[bucket_id]
- none_filler = bucket_boundary - 1
- shapes = make_padded_shapes(
- padded_shapes or grouped_dataset.output_shapes,
- none_filler=none_filler)
- return grouped_dataset.padded_batch(batch_size, shapes, padding_values)
-
- def _apply_fn(dataset):
- return dataset.apply(
- group_by_window(element_to_bucket_id, batching_fn,
- window_size_func=window_size_fn))
-
- return _apply_fn
-
-
-def _map_x_dataset(map_func):
- """A transformation that maps `map_func` across its input.
-
- This transformation is similar to `tf.data.Dataset.map`, but in addition to
- supporting dense and sparse tensor inputs, it also supports dataset inputs.
-
- Args:
- map_func: A function mapping a nested structure of tensors and/or datasets
- (having shapes and types defined by `self.output_shapes` and
- `self.output_types`) to another nested structure of tensors and/or
- datasets.
-
- Returns:
- Dataset: A `Dataset`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _MapXDataset(dataset, map_func)
-
- return _apply_fn
-
-
-# TODO(b/115382007) Remove this once canned reducers move to core.
-def window_dataset(window_size):
- """A transformation that creates window datasets from the input dataset.
-
- The resulting datasets will contain `window_size` elements (or
- `N % window_size` for the last dataset if `window_size` does not divide the
- number of input elements `N` evenly).
-
- Args:
- window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
- consecutive elements of the input dataset to combine into a window.
-
- Returns:
- Dataset: A `Dataset`.
- """
-
- def _apply_fn(dataset):
- return dataset_ops.WindowDataset(
- dataset,
- size=window_size,
- shift=window_size,
- stride=1,
- drop_remainder=False)
-
- return _apply_fn
-
-
-class _GroupByReducerDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that groups its input and performs a reduction."""
-
- def __init__(self, input_dataset, key_func, reducer):
- """See `group_by_reducer()` for details."""
- super(_GroupByReducerDataset, self).__init__(input_dataset)
+ return grouping.bucket_by_sequence_length(
+ element_length_func, bucket_boundaries, bucket_batch_sizes, padded_shapes,
+ padding_values, pad_to_bucket_boundary, no_padding)
- self._input_dataset = input_dataset
- self._make_key_func(key_func, input_dataset)
- self._make_init_func(reducer.init_func)
- self._make_reduce_func(reducer.reduce_func, input_dataset)
- self._make_finalize_func(reducer.finalize_func)
-
- def _make_key_func(self, key_func, input_dataset):
- """Make wrapping Defun for key_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- key_func, "tf.contrib.data.group_by_reducer()", input_dataset)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`key_func` must return a single tf.int64 tensor. "
- "Got type=%s and shape=%s"
- % (wrapped_func.output_types, wrapped_func.output_shapes))
- self._key_func = wrapped_func.function
-
- def _make_init_func(self, init_func):
- """Make wrapping Defun for init_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- init_func, "tf.contrib.data.group_by_reducer()",
- input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
- input_types=dtypes.int64)
- self._init_func = wrapped_func.function
- self._state_classes = wrapped_func.output_classes
- self._state_shapes = wrapped_func.output_shapes
- self._state_types = wrapped_func.output_types
-
- def _make_reduce_func(self, reduce_func, input_dataset):
- """Make wrapping Defun for reduce_func."""
-
- # Iteratively rerun the reduce function until reaching a fixed point on
- # `self._state_shapes`.
- need_to_rerun = True
- while need_to_rerun:
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- reduce_func, "tf.contrib.data.group_by_reducer()",
- input_classes=(self._state_classes, input_dataset.output_classes),
- input_shapes=(self._state_shapes, input_dataset.output_shapes),
- input_types=(self._state_types, input_dataset.output_types),
- add_to_graph=False)
-
- # Extract and validate class information from the returned values.
- for new_state_class, state_class in zip(
- nest.flatten(wrapped_func.output_classes),
- nest.flatten(self._state_classes)):
- if not issubclass(new_state_class, state_class):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes, wrapped_func.output_classes))
-
- # Extract and validate type information from the returned values.
- for new_state_type, state_type in zip(
- nest.flatten(wrapped_func.output_types),
- nest.flatten(self._state_types)):
- if new_state_type != state_type:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types, wrapped_func.output_types))
-
- # Extract shape information from the returned values.
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
- weakened_state_shapes = [
- original.most_specific_compatible_shape(new)
- for original, new in zip(flat_state_shapes, flat_new_state_shapes)
- ]
-
- need_to_rerun = False
- for original_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if original_shape.ndims is not None and (
- weakened_shape.ndims is None or
- original_shape.as_list() != weakened_shape.as_list()):
- need_to_rerun = True
- break
-
- if need_to_rerun:
- self._state_shapes = nest.pack_sequence_as(self._state_shapes,
- weakened_state_shapes)
-
- self._reduce_func = wrapped_func.function
- self._reduce_func.add_to_graph(ops.get_default_graph())
-
- def _make_finalize_func(self, finalize_func):
- """Make wrapping Defun for finalize_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- finalize_func, "tf.contrib.data.group_by_reducer()",
- input_classes=self._state_classes, input_shapes=self._state_shapes,
- input_types=self._state_types)
- self._finalize_func = wrapped_func.function
- self._output_classes = wrapped_func.output_classes
- self._output_shapes = wrapped_func.output_shapes
- self._output_types = wrapped_func.output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.group_by_reducer_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._key_func.captured_inputs,
- self._init_func.captured_inputs,
- self._reduce_func.captured_inputs,
- self._finalize_func.captured_inputs,
- key_func=self._key_func,
- init_func=self._init_func,
- reduce_func=self._reduce_func,
- finalize_func=self._finalize_func,
- **dataset_ops.flat_structure(self))
-
-
-class _GroupByWindowDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that groups its input and performs a windowed reduction."""
-
- def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
- """See `group_by_window()` for details."""
- super(_GroupByWindowDataset, self).__init__(input_dataset)
-
- self._input_dataset = input_dataset
-
- self._make_key_func(key_func, input_dataset)
- self._make_reduce_func(reduce_func, input_dataset)
- self._make_window_size_func(window_size_func)
-
- def _make_window_size_func(self, window_size_func):
- """Make wrapping Defun for window_size_func."""
- def window_size_func_wrapper(key):
- return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- window_size_func_wrapper, "tf.contrib.data.group_by_window()",
- input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
- input_types=dtypes.int64)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`window_size_func` must return a single tf.int64 scalar tensor.")
- self._window_size_func = wrapped_func.function
-
- def _make_key_func(self, key_func, input_dataset):
- """Make wrapping Defun for key_func."""
- def key_func_wrapper(*args):
- return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`key_func` must return a single tf.int64 scalar tensor.")
- self._key_func = wrapped_func.function
-
- def _make_reduce_func(self, reduce_func, input_dataset):
- """Make wrapping Defun for reduce_func."""
- nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- reduce_func, "tf.contrib.data.reduce_by_window()",
- input_classes=(ops.Tensor, nested_dataset),
- input_shapes=(tensor_shape.scalar(), nested_dataset),
- input_types=(dtypes.int64, nested_dataset),
- experimental_nested_dataset_support=True)
- if not isinstance(
- wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
- raise TypeError("`reduce_func` must return a `Dataset` object.")
- self._output_classes = wrapped_func.output_classes.output_classes
- self._output_types = wrapped_func.output_types.output_types
- self._output_shapes = wrapped_func.output_shapes.output_shapes
- self._reduce_func = wrapped_func.function
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.group_by_window_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._key_func.captured_inputs,
- self._reduce_func.captured_inputs,
- self._window_size_func.captured_inputs,
- key_func=self._key_func,
- reduce_func=self._reduce_func,
- window_size_func=self._window_size_func,
- **dataset_ops.flat_structure(self))
-
-
-class Reducer(object):
+class Reducer(grouping.Reducer):
"""A reducer is used for reducing a set of elements.
A reducer is represented as a tuple of the three functions:
@@ -507,58 +152,6 @@ class Reducer(object):
3) finalization function: state => result
"""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.Reducer(...)`.")
def __init__(self, init_func, reduce_func, finalize_func):
- self._init_func = init_func
- self._reduce_func = reduce_func
- self._finalize_func = finalize_func
-
- @property
- def init_func(self):
- return self._init_func
-
- @property
- def reduce_func(self):
- return self._reduce_func
-
- @property
- def finalize_func(self):
- return self._finalize_func
-
-
-class _MapXDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that maps a function over elements in its input."""
-
- def __init__(self, input_dataset, map_func):
- """See `map_x_dataset()` for details."""
- super(_MapXDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- map_func,
- "tf.contrib.data.map_x_dataset()",
- input_dataset,
- experimental_nested_dataset_support=True)
- self._output_classes = wrapped_func.output_classes
- self._output_shapes = wrapped_func.output_shapes
- self._output_types = wrapped_func.output_types
- self._map_func = wrapped_func.function
-
- def _as_variant_tensor(self):
- input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
- return gen_dataset_ops.map_dataset(
- input_t,
- self._map_func.captured_inputs,
- f=self._map_func,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
+ super(Reducer, self).__init__(init_func, reduce_func, finalize_func)
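A hedged usage sketch of the moved grouping endpoint (the values and window size are arbitrary; `key_func` must return a scalar tf.int64, which `x % 2` does here):

import tensorflow as tf

# Group integers by parity and emit batches of up to 3 per group.
dataset = tf.data.Dataset.range(10).apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,
        reduce_func=lambda key, window: window.batch(3),
        window_size=3))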
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 1ee9db1aa8..f50da4d429 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,20 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import stateless
-from tensorflow.contrib.data.python.ops import random_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import readers
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.parallel_interleave(...)`.")
def parallel_interleave(map_func,
cycle_length,
block_length=1,
@@ -80,12 +72,9 @@ def parallel_interleave(map_func,
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return readers.ParallelInterleaveDataset(
- dataset, map_func, cycle_length, block_length, sloppy,
- buffer_output_elements, prefetch_input_elements)
-
- return _apply_fn
+ return interleave_ops.parallel_interleave(
+ map_func, cycle_length, block_length, sloppy, buffer_output_elements,
+ prefetch_input_elements)
@deprecation.deprecated(
@@ -139,63 +128,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return readers.ParallelInterleaveDataset(
- dataset,
- map_func,
- cycle_length,
- block_length,
- sloppy=True,
- buffer_output_elements=None,
- prefetch_input_elements=None)
-
- return _apply_fn
-
-
-class _DirectedInterleaveDataset(dataset_ops.Dataset):
- """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
-
- def __init__(self, selector_input, data_inputs):
- self._selector_input = selector_input
- self._data_inputs = list(data_inputs)
-
- for data_input in data_inputs[1:]:
- if (data_input.output_types != data_inputs[0].output_types or
- data_input.output_classes != data_inputs[0].output_classes):
- raise TypeError("All datasets must have the same type and class.")
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- return (
- gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
- self._selector_input._as_variant_tensor(), [
- data_input._as_variant_tensor()
- for data_input in self._data_inputs
- ], **dataset_ops.flat_structure(self)))
- # pylint: enable=protected-access
-
- def _inputs(self):
- return [self._selector_input] + self._data_inputs
-
- @property
- def output_classes(self):
- return self._data_inputs[0].output_classes
-
- @property
- def output_shapes(self):
- ret = self._data_inputs[0].output_shapes
- for data_input in self._data_inputs[1:]:
- ret = nest.pack_sequence_as(ret, [
- ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
- nest.flatten(ret), nest.flatten(data_input.output_shapes))
- ])
- return ret
-
- @property
- def output_types(self):
- return self._data_inputs[0].output_types
+ return interleave_ops.parallel_interleave(
+ map_func, cycle_length, block_length, sloppy=True)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.sample_from_datasets(...)`.")
def sample_from_datasets(datasets, weights=None, seed=None):
"""Samples elements at random from the datasets in `datasets`.
@@ -219,64 +157,11 @@ def sample_from_datasets(datasets, weights=None, seed=None):
ValueError: If the `weights` argument is specified and does not match the
length of the `datasets` element.
"""
- num_datasets = len(datasets)
- if not isinstance(weights, dataset_ops.Dataset):
- if weights is None:
- # Select inputs with uniform probability.
- logits = [[1.0] * num_datasets]
-
- else:
- # Use the given `weights` as the probability of choosing the respective
- # input.
- weights = ops.convert_to_tensor(weights, name="weights")
- if weights.dtype not in (dtypes.float32, dtypes.float64):
- raise TypeError("`weights` must be convertible to a tensor of "
- "`tf.float32` or `tf.float64` elements.")
- if not weights.shape.is_compatible_with([num_datasets]):
- raise ValueError(
- "`weights` must be a vector of length `len(datasets)`.")
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed
- # to weights.
- logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
-
- # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
- # is a `Dataset`, it is possible that evaluating it has a side effect the
- # user depends on.
- if len(datasets) == 1:
- return datasets[0]
-
- def select_dataset_constant_logits(seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
-
- selector_input = dataset_ops.MapDataset(
- random_ops.RandomDataset(seed).batch(2),
- select_dataset_constant_logits,
- use_inter_op_parallelism=False)
-
- else:
- # Use each element of the given `weights` dataset as the probability of
- # choosing the respective input.
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed to
- # weights.
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
-
- def select_dataset_varying_logits(logits, seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
-
- logits_and_seeds = dataset_ops.Dataset.zip(
- (logits_ds, random_ops.RandomDataset(seed).batch(2)))
- selector_input = dataset_ops.MapDataset(
- logits_and_seeds,
- select_dataset_varying_logits,
- use_inter_op_parallelism=False)
-
- return _DirectedInterleaveDataset(selector_input, datasets)
+ return interleave_ops.sample_from_datasets(datasets, weights, seed)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.choose_from_datasets(...)`.")
def choose_from_datasets(datasets, choice_dataset):
"""Creates a dataset that deterministically chooses elements from `datasets`.
@@ -312,10 +197,4 @@ def choose_from_datasets(datasets, choice_dataset):
TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
type.
"""
- if not (choice_dataset.output_types == dtypes.int64
- and choice_dataset.output_shapes.is_compatible_with(
- tensor_shape.scalar())
- and choice_dataset.output_classes == ops.Tensor):
- raise TypeError("`choice_dataset` must be a dataset of scalar "
- "`tf.int64` tensors.")
- return _DirectedInterleaveDataset(choice_dataset, datasets)
+ return interleave_ops.choose_from_datasets(datasets, choice_dataset)
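A minimal sketch of the moved interleave endpoints; the file pattern is a placeholder and not part of this change:

import tensorflow as tf

# Read several TFRecord shards concurrently, tolerating non-deterministic order.
files = tf.data.Dataset.list_files("/tmp/train-*.tfrecord")  # hypothetical path
records = files.apply(
    tf.data.experimental.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=4, sloppy=True))

# Mix two datasets, drawing from each with equal probability.
mixed = tf.data.experimental.sample_from_datasets(
    [tf.data.Dataset.range(0, 3), tf.data.Dataset.range(100, 103)],
    weights=[0.5, 0.5])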
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 18515e21ed..48c325c86f 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -16,15 +16,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training import session_run_hook
+from tensorflow.python.data.experimental.ops import iterator_ops
+from tensorflow.python.util import deprecation
+
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_saveable_from_iterator(...)`.")
def make_saveable_from_iterator(iterator):
"""Returns a SaveableObject for saving/restore iterator state using Saver.
@@ -60,27 +58,10 @@ def make_saveable_from_iterator(iterator):
Note: Not all iterators support checkpointing yet. Attempting to save the
state of an unsupported iterator will throw an error.
"""
- return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
-
-
-class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
- """SaveableObject for saving/restoring iterator state."""
+ return iterator_ops.make_saveable_from_iterator(iterator)
- def __init__(self, iterator_resource):
- serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
- specs = [
- saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
- iterator_resource.name + "-state")
- ]
- super(_Saveable, self).__init__(iterator_resource, specs,
- iterator_resource.name)
- def restore(self, restored_tensors, unused_restored_shapes):
- with ops.colocate_with(self.op):
- return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
-
-
-class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+class CheckpointInputPipelineHook(iterator_ops.CheckpointInputPipelineHook):
"""Checkpoints input pipeline state every N steps or seconds.
This hook saves the state of the iterators in the `Graph` so that when
@@ -125,135 +106,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
collector when building the eval graph.
"""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.CheckpointInputPipelineHook(...)`.")
def __init__(self, estimator):
- """Initializes a `CheckpointInputPipelineHook`.
-
- Args:
- estimator: Estimator.
-
- Raises:
- ValueError: One of `save_steps` or `save_secs` should be set.
- ValueError: At most one of saver or scaffold should be set.
- """
- # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
- # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
- # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
- # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
- # to be different to avoid conflicts with the model checkpoint.
-
- # pylint: disable=protected-access
- checkpoint_prefix = "input"
- if estimator._config.num_worker_replicas > 1:
- # Distributed setting.
- suffix = "_{}_{}".format(estimator._config.task_type,
- estimator._config.task_id)
- checkpoint_prefix += suffix
- # pylint: enable=protected-access
-
- # We use a composition paradigm instead of inheriting from
- # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
- # to check whether a `CheckpointSaverHook` is already present in the list
- # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
- # would thwart this behavior. This hook checkpoints *only the iterators*
- # and not the graph variables.
- self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
- estimator.model_dir,
- save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
- save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
- checkpoint_basename=checkpoint_prefix + ".ckpt")
-
- # Name for the protocol buffer file that will contain the list of most
- # recent checkpoints stored as a `CheckpointState` protocol buffer.
- # This file, kept in the same directory as the checkpoint files, is
- # automatically managed by the `Saver` to keep track of recent checkpoints.
- # The default name used by the `Saver` for this file is "checkpoint". Here
- # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
- # `checkpoint_dir` is the same as the model checkpoint directory, there are
- # no conflicts during restore.
- self._latest_filename = "checkpoint_" + checkpoint_prefix
- self._first_run = True
-
- def begin(self):
- # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
- # collection if no `Saver` or `Scaffold` is provided.
- # pylint: disable=protected-access
- if (self._checkpoint_saver_hook._saver is None and
- self._checkpoint_saver_hook._scaffold is None):
- iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
- saveables = [_Saveable(i) for i in iterators]
- self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
- self._latest_filename)
- # pylint: enable=protected-access
- self._checkpoint_saver_hook.begin()
-
- def _restore_or_save_initial_ckpt(self, session):
- # Ideally this should be run in after_create_session but is not for the
- # following reason:
- # Currently there is no way of enforcing an order of running the
- # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
- # is run *after* this hook. That is troublesome because
- # 1. If a checkpoint exists and this hook restores it, the initializer hook
- # will override it.
- # 2. If no checkpoint exists, this hook will try to save an initialized
- # iterator which will result in an exception.
- #
- # As a temporary fix we enter the following implicit contract between this
- # hook and the _DatasetInitializerHook.
- # 1. The _DatasetInitializerHook initializes the iterator in the call to
- # after_create_session.
- # 2. This hook saves the iterator on the first call to `before_run()`, which
- # is guaranteed to happen after `after_create_session()` of all hooks
- # have been run.
-
- # Check if there is an existing checkpoint. If so, restore from it.
- # pylint: disable=protected-access
- latest_checkpoint_path = checkpoint_management.latest_checkpoint(
- self._checkpoint_saver_hook._checkpoint_dir,
- latest_filename=self._latest_filename)
- if latest_checkpoint_path:
- self._checkpoint_saver_hook._get_saver().restore(session,
- latest_checkpoint_path)
- else:
- # The checkpoint saved here is the state at step "global_step".
- # Note: We do not save the GraphDef or MetaGraphDef here.
- global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
- self._checkpoint_saver_hook._save(session, global_step)
- self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
- # pylint: enable=protected-access
-
- def before_run(self, run_context):
- if self._first_run:
- self._restore_or_save_initial_ckpt(run_context.session)
- self._first_run = False
- return self._checkpoint_saver_hook.before_run(run_context)
-
- def after_run(self, run_context, run_values):
- self._checkpoint_saver_hook.after_run(run_context, run_values)
-
- def end(self, session):
- self._checkpoint_saver_hook.end(session)
-
-
-class _CustomSaver(saver_lib.Saver):
- """`Saver` with a different default `latest_filename`.
-
- This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
- the model ckpt saved by the `CheckpointSaverHook`.
- """
-
- def __init__(self, var_list, latest_filename):
- super(_CustomSaver, self).__init__(var_list)
- self._latest_filename = latest_filename
-
- def save(self,
- sess,
- save_path,
- global_step=None,
- latest_filename=None,
- meta_graph_suffix="meta",
- write_meta_graph=True,
- write_state=True,
- strip_default_attrs=False):
- return super(_CustomSaver, self).save(
- sess, save_path, global_step, latest_filename or self._latest_filename,
- meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
+ super(CheckpointInputPipelineHook, self).__init__(estimator)
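A short sketch of the saveable-iterator pattern through the new endpoint (TF 1.x graph mode assumed):

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
# Register the iterator state so a plain Saver checkpoint captures it.
saveable = tf.data.experimental.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()  # now saves/restores the iterator position as well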
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index cfbba701b0..3aeee9d8e4 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -17,92 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import parsing_ops
+from tensorflow.python.data.experimental.ops import parsing_ops
+from tensorflow.python.util import deprecation
-class _ParseExampleDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that parses `example` dataset into a `dict` dataset."""
-
- def __init__(self, input_dataset, features, num_parallel_calls):
- super(_ParseExampleDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if not all(types == dtypes.string
- for types in nest.flatten(input_dataset.output_types)):
- raise TypeError("Input dataset should be a dataset of vectors of strings")
- self._num_parallel_calls = num_parallel_calls
- # pylint: disable=protected-access
- self._features = parsing_ops._prepend_none_dimension(features)
- # sparse_keys and dense_keys come back sorted here.
- (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
- dense_shapes) = parsing_ops._features_to_raw_params(
- self._features, [
- parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
- parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
- ])
- # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
- (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
- dense_shape_as_shape) = parsing_ops._process_raw_parameters(
- None, dense_defaults, sparse_keys, sparse_types, dense_keys,
- dense_types, dense_shapes)
- # pylint: enable=protected-access
- self._sparse_keys = sparse_keys
- self._sparse_types = sparse_types
- self._dense_keys = dense_keys
- self._dense_defaults = dense_defaults_vec
- self._dense_shapes = dense_shapes
- self._dense_types = dense_types
- dense_output_shapes = [
- self._input_dataset.output_shapes.concatenate(shape)
- for shape in dense_shape_as_shape
- ]
- sparse_output_shapes = [
- self._input_dataset.output_shapes.concatenate([None])
- for _ in range(len(sparse_keys))
- ]
-
- self._output_shapes = dict(
- zip(self._dense_keys + self._sparse_keys,
- dense_output_shapes + sparse_output_shapes))
- self._output_types = dict(
- zip(self._dense_keys + self._sparse_keys,
- self._dense_types + self._sparse_types))
- self._output_classes = dict(
- zip(self._dense_keys + self._sparse_keys,
- [ops.Tensor for _ in range(len(self._dense_defaults))] +
- [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
- ]))
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.parse_example_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._num_parallel_calls,
- self._dense_defaults,
- self._sparse_keys,
- self._dense_keys,
- self._sparse_types,
- self._dense_shapes,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
-
-# TODO(b/111553342): add arguments names and example names as well.
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.parse_example_dataset(...)`.")
def parse_example_dataset(features, num_parallel_calls=1):
"""A transformation that parses `Example` protos into a `dict` of tensors.
@@ -130,21 +50,4 @@ def parse_example_dataset(features, num_parallel_calls=1):
Raises:
ValueError: if features argument is None.
"""
- if features is None:
- raise ValueError("Missing: features was %s." % features)
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
- if any([
- isinstance(feature, parsing_ops.SparseFeature)
- for _, feature in features.items()
- ]):
- # pylint: disable=protected-access
- # pylint: disable=g-long-lambda
- out_dataset = out_dataset.map(
- lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
- features, x), num_parallel_calls=num_parallel_calls)
- return out_dataset
-
- return _apply_fn
+ return parsing_ops.parse_example_dataset(features, num_parallel_calls)
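A hedged sketch of the moved parser; the feature spec and file path are illustrative only. Note that the input must be batched vectors of serialized `Example` strings:

import tensorflow as tf

features = {
    "label": tf.FixedLenFeature([], tf.int64),  # hypothetical feature
    "text": tf.VarLenFeature(tf.string),        # hypothetical feature
}
dataset = tf.data.TFRecordDataset("/tmp/examples.tfrecord")  # placeholder path
dataset = dataset.batch(32).apply(
    tf.data.experimental.parse_example_dataset(features, num_parallel_calls=4))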
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 46f82e453a..adfb390cd9 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -17,321 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import device as framework_device
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.ops import resource_variable_ops
-
-
-def function_buffering_resource(string_arg,
- target_device,
- f,
- buffer_size,
- output_types,
- container="",
- shared_name=None,
- name=None):
- """Creates a FunctionBufferingResource.
-
- A FunctionBufferingResource fills up a buffer by calling a function `f` on
- `target_device`. `f` should take in only a single string argument as input.
-
- Args:
- string_arg: The single string argument to the function.
- target_device: The device to run `f` on.
- f: The function to be executed.
- buffer_size: Size of the buffer to be populated.
- output_types: The output types generated by the function.
- container: (Optional) string. Defaults to "".
- shared_name: (Optional) string.
- name: (Optional) string to name the op.
-
- Returns:
- Handle to a FunctionBufferingResource.
- """
- if shared_name is None:
- shared_name = ""
- return ged_ops.experimental_function_buffering_resource(
- string_arg=string_arg,
- target_device=target_device,
- shared_name=shared_name,
- f=f,
- buffer_size=buffer_size,
- container=container,
- name=name,
- output_types=output_types)
-
-
-def function_buffering_resource_get_next(function_buffer_resource,
- output_types,
- name=None):
- return ged_ops.experimental_function_buffering_resource_get_next(
- function_buffer_resource=function_buffer_resource,
- output_types=output_types,
- name=name)
-
-
-def function_buffering_resource_reset(function_buffer_resource, name=None):
- return ged_ops.experimental_function_buffering_resource_reset(
- function_buffer_resource=function_buffer_resource, name=name)
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset
- one_shot: If true, we make a one shot iterator that's already initialized.
- device: A fully specified device string where we want to prefetch to
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server).
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- one_shot,
- device,
- buffer_size,
- shared_name=None):
- self._input_dataset = input_dataset
- self._get_next_call_count = 0
- self._one_shot = one_shot
- if shared_name is None:
- shared_name = ""
-
- if self._one_shot:
- self._input_iterator = input_dataset.make_one_shot_iterator()
- else:
- self._input_iterator = iterator_ops.Iterator.from_structure(
- self._input_dataset.output_types, self._input_dataset.output_shapes,
- shared_name, self._input_dataset.output_classes)
- input_iterator_handle = self._input_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self._input_iterator.output_types,
- self._input_iterator.output_shapes,
- self._input_iterator.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- iterator_device = ged_ops.experimental_iterator_get_device(
- self._input_iterator._iterator_resource)
-
- with ops.device(device):
- self._buffering_resource = function_buffering_resource(
- f=_prefetch_fn,
- target_device=iterator_device,
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes)))
-
- if not self._one_shot:
- reset_op = function_buffering_resource_reset(self._buffering_resource)
- with ops.control_dependencies([reset_op]):
- self._initializer = self._input_iterator.make_initializer(
- self._input_dataset)
-
- def get_next(self, name=None):
- """See `tf.data.Iterator.get_next`."""
- self._get_next_call_count += 1
- if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
- warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
- flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
- self._buffering_resource,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- name=name)
-
- ret = sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self.output_types, flat_ret),
- self.output_types, self.output_shapes, self.output_classes)
-
- for tensor, shape in zip(
- nest.flatten(ret), nest.flatten(self.output_shapes)):
- if isinstance(tensor, ops.Tensor):
- tensor.set_shape(shape)
-
- return ret
-
- @property
- def initializer(self):
- if self._one_shot:
- raise NotImplementedError("Can't initialize a one_shot_iterator")
- return self._initializer
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset
- one_shot: If true, we make a one shot iterator that's already initialized.
- device: A fully specified device string where we want to prefetch to
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server).
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- device,
- buffer_size):
- with ops.device("/device:CPU:0"):
- super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
- input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
- self._resource)
-
- self._device = device
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self.output_types, self.output_shapes, self.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- _prefetch_fn.add_to_graph(None)
-
- with ops.device(device):
- self._buffering_resource = function_buffering_resource(
- f=_prefetch_fn,
- output_types=self._flat_output_types,
- target_device=ged_ops.experimental_iterator_get_device(
- self._resource),
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=iterator_ops._generate_shared_name(
- "function_buffer_resource"))
-
- def _next_internal(self):
- """Returns a nested structure of `tf.Tensor`s containing the next element.
- """
- # This runs in sync mode as iterators use an error status to communicate
- # that there is no more data to iterate over.
- # TODO(b/77291417): Fix
- with context.execution_mode(context.SYNC):
- with ops.device(self._device):
- ret = ged_ops.experimental_function_buffering_resource_get_next(
- function_buffer_resource=self._buffering_resource,
- output_types=self._flat_output_types)
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self._output_types, ret), self._output_types,
- self._output_shapes, self._output_classes)
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` whose iterator prefetches elements to another device."""
-
- def __init__(self, input_dataset, device, buffer_size):
- super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._device = device
- self._buffer_size = buffer_size if buffer_size is not None else 1
-
- # The static analysis cannot tell that the eager iterator's superclass has
- # a `next()` method.
- # pylint: disable=non-iterator-returned
- def __iter__(self):
- """Creates an `Iterator` for enumerating the elements of this dataset.
-
- The returned iterator implements the Python iterator protocol and therefore
- can only be used in eager mode.
-
- Returns:
- An `Iterator` over the elements of this dataset.
-
- Raises:
- RuntimeError: If eager execution is enabled.
- """
- if context.executing_eagerly():
- return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
- self._buffer_size)
- else:
- raise RuntimeError("dataset.__iter__() is only supported when eager "
- "execution is enabled.")
- # pylint: enable=non-iterator-returned
-
- def make_one_shot_iterator(self):
- if context.executing_eagerly():
- return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
- self._buffer_size)
- else:
- return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
- device=self._device,
- buffer_size=self._buffer_size)
-
- def make_initializable_iterator(self, shared_name=None):
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=False,
- device=self._device,
- buffer_size=self._buffer_size,
- shared_name=shared_name)
-
- def _as_variant_tensor(self):
- # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
-    # transformation methods is called).
- # TODO(mrry): Investigate support for chaining further transformations after
- # the prefetch, including GPU support.
- raise NotImplementedError("`prefetch_to_device()` must be the last "
- "transformation in a dataset pipeline.")
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.prefetch_to_device(...)`.")
def prefetch_to_device(device, buffer_size=None):
"""A transformation that prefetches dataset values to the given `device`.
@@ -347,12 +38,10 @@ def prefetch_to_device(device, buffer_size=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return _PrefetchToDeviceDataset(dataset, device, buffer_size)
-
- return _apply_fn
+ return prefetching_ops.prefetch_to_device(device, buffer_size)
+@deprecation.deprecated(None, "Use `tf.data.experimental.copy_to_device(...)`.")
def copy_to_device(target_device, source_device="/cpu:0"):
"""A transformation that copies dataset elements to the given `target_device`.
@@ -364,165 +53,4 @@ def copy_to_device(target_device, source_device="/cpu:0"):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _CopyToDeviceDataset(
- dataset, target_device=target_device, source_device=source_device)
-
- return _apply_fn
-
-
-# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
-# all inputs to the Op are in host memory, thereby avoiding some unnecessary
-# Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that copies elements to another device."""
-
- def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
- """Constructs a _CopyToDeviceDataset.
-
- Args:
- input_dataset: `Dataset` to be copied
- target_device: The name of the device to which elements would be copied.
- source_device: Device where input_dataset would be placed.
- """
- super(_CopyToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._target_device = target_device
- spec = framework_device.DeviceSpec().from_string(self._target_device)
- self._is_gpu_target = (spec.device_type == "GPU")
- self._source_device_string = source_device
- self._source_device = ops.convert_to_tensor(source_device)
-
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._input_dataset.output_shapes,
- self._input_dataset.output_classes))
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes))
-
- @function.Defun()
- def _init_func():
- """Creates an iterator for the input dataset.
-
- Returns:
- A `string` tensor that encapsulates the iterator created.
- """
- # pylint: disable=protected-access
- ds_variant = self._input_dataset._as_variant_tensor()
- resource = gen_dataset_ops.anonymous_iterator(
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- with ops.control_dependencies(
- [gen_dataset_ops.make_iterator(ds_variant, resource)]):
- return gen_dataset_ops.iterator_to_string_handle(resource)
-
- @function.Defun()
- def _remote_init_func():
- return functional_ops.remote_call(
- target=self._source_device,
- args=_init_func.captured_inputs,
- Tout=[dtypes.string],
- f=_init_func)
-
- self._init_func = _remote_init_func
- self._init_captured_args = _remote_init_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _next_func(string_handle):
- """Calls get_next for created iterator.
-
- Args:
- string_handle: An iterator string handle created by _init_func
- Returns:
- The elements generated from `input_dataset`
- """
- with ops.device(self._source_device_string):
- iterator = iterator_ops.Iterator.from_string_handle(
- string_handle, self.output_types, self.output_shapes,
- self.output_classes)
- ret = iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- @function.Defun(dtypes.string)
- def _remote_next_func(string_handle):
- return functional_ops.remote_call(
- target=self._source_device,
- args=[string_handle] + _next_func.captured_inputs,
- Tout=self._flat_output_types,
- f=_next_func)
-
- self._next_func = _remote_next_func
- self._next_captured_args = _remote_next_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _finalize_func(string_handle):
- """Destroys the iterator resource created.
-
- Args:
- string_handle: An iterator string handle created by _init_func
- Returns:
- Tensor constant 0
- """
- iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
- string_handle,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- with ops.control_dependencies([
- resource_variable_ops.destroy_resource_op(
- iterator_resource, ignore_lookup_error=True)]):
- return array_ops.constant(0, dtypes.int64)
-
- @function.Defun(dtypes.string)
- def _remote_finalize_func(string_handle):
- return functional_ops.remote_call(
- target=self._source_device,
- args=[string_handle] + _finalize_func.captured_inputs,
- Tout=[dtypes.int64],
- f=_finalize_func)
-
- self._finalize_func = _remote_finalize_func
- self._finalize_captured_args = _remote_finalize_func.captured_inputs
-
- g = ops.get_default_graph()
- _remote_init_func.add_to_graph(g)
- _remote_next_func.add_to_graph(g)
- _remote_finalize_func.add_to_graph(g)
-    # pylint: enable=protected-access
-
- # The one_shot_iterator implementation needs a 0 arg _make_dataset function
- # that thereby captures all the inputs required to create the dataset. Since
- # there are strings that are inputs to the GeneratorDataset which can't be
- # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
- # GPU
- def make_one_shot_iterator(self):
- if self._is_gpu_target:
- raise ValueError("Cannot create a one shot iterator when using "
- "`tf.contrib.data.copy_to_device()` on GPU. Please use "
- "`Dataset.make_initializable_iterator()` instead.")
- else:
- return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
-
- def _as_variant_tensor(self):
- with ops.device(self._target_device):
- return gen_dataset_ops.generator_dataset(
- self._init_captured_args,
- self._next_captured_args,
- self._finalize_captured_args,
- init_func=self._init_func,
- next_func=self._next_func,
- finalize_func=self._finalize_func,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
+ return prefetching_ops.copy_to_device(target_device, source_device)
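A minimal sketch of the moved prefetch endpoint; as the deleted implementation above also enforced, it must remain the final transformation in the pipeline:

import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(2)
dataset = dataset.apply(
    tf.data.experimental.prefetch_to_device("/gpu:0", buffer_size=1))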
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index 344a0763c8..2c95125636 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -17,36 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.util import deprecation
-class RandomDataset(dataset_ops.DatasetSource):
+class RandomDataset(random_ops.RandomDataset):
"""A `Dataset` of pseudorandom values."""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.RandomDataset(...)`.")
def __init__(self, seed=None):
- """A `Dataset` of pseudorandom values."""
- super(RandomDataset, self).__init__()
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.random_dataset(
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return ops.Tensor
-
- @property
- def output_shapes(self):
- return tensor_shape.scalar()
-
- @property
- def output_types(self):
- return dtypes.int64
+ super(RandomDataset, self).__init__(seed)
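The rewritten module above illustrates the shim pattern used throughout this change: the contrib class now subclasses its `tf.data.experimental` counterpart and its constructor is wrapped in `@deprecation.deprecated`, so old call sites keep working while logging a warning. A before/after sketch, assuming a build where both endpoints are present:

```python
import tensorflow as tf

# Old endpoint: still constructs the same dataset, but logs a deprecation
# warning and forwards to the implementation under tf.data.experimental.
ds_old = tf.contrib.data.RandomDataset(seed=42)

# New endpoint: identical behavior, no warning.
ds_new = tf.data.experimental.RandomDataset(seed=42)
```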
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 360971e200..4601376dff 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -17,295 +17,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-import csv
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import parsing_ops
-from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.lib.io import file_io
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
-from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
-_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
- dtypes.int64, dtypes.string)
-
-
-def _is_valid_int32(str_val):
- try:
- # Checks equality to prevent int32 overflow
- return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
- str_val)
- except (ValueError, OverflowError):
- return False
-
-
-def _is_valid_int64(str_val):
- try:
- dtypes.int64.as_numpy_dtype(str_val)
- return True
- except (ValueError, OverflowError):
- return False
-
-
-def _is_valid_float(str_val, float_dtype):
- try:
- return float_dtype.as_numpy_dtype(str_val) < np.inf
- except ValueError:
- return False
-
-
-def _infer_type(str_val, na_value, prev_type):
- """Given a string, infers its tensor type.
-
- Infers the type of a value by picking the least 'permissive' type possible,
- while still allowing the previous type inference for this column to be valid.
-
- Args:
- str_val: String value to infer the type of.
- na_value: Additional string to recognize as a NA/NaN CSV value.
- prev_type: Type previously inferred based on values of this column that
- we've seen up till now.
- Returns:
- Inferred dtype.
- """
- if str_val in ("", na_value):
- # If the field is null, it gives no extra information about its type
- return prev_type
-
- type_list = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
- ] # list of types to try, ordered from least permissive to most
-
- type_functions = [
- _is_valid_int32,
- _is_valid_int64,
- lambda str_val: _is_valid_float(str_val, dtypes.float32),
- lambda str_val: _is_valid_float(str_val, dtypes.float64),
- lambda str_val: True,
- ] # Corresponding list of validation functions
-
- for i in range(len(type_list)):
- validation_fn = type_functions[i]
- if validation_fn(str_val) and (prev_type is None or
- prev_type in type_list[:i + 1]):
- return type_list[i]
-
-
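The inference rule above picks the least permissive dtype that accepts the current cell while never narrowing a type already required by earlier rows in the column. A standalone sketch of that rule, using hypothetical helper names that are not part of the TensorFlow API:

```python
import numpy as np

# Try the least permissive type first, but never pick a type narrower than
# what earlier rows already required.
_ORDER = ["int32", "int64", "float32", "float64", "string"]

def _valid(value, type_name):
  try:
    if type_name == "int32":
      return np.int32(value) == np.int64(value)  # guard against int32 overflow
    if type_name == "int64":
      np.int64(value)
      return True
    if type_name in ("float32", "float64"):
      return getattr(np, type_name)(value) < np.inf
    return True  # everything is a valid string
  except (ValueError, OverflowError):
    return False

def infer_type(value, prev_type=None):
  if value == "":  # null cells give no information; keep the previous type
    return prev_type
  for i, t in enumerate(_ORDER):
    if _valid(value, t) and (prev_type is None or prev_type in _ORDER[:i + 1]):
      return t

print(infer_type("3"))                 # int32
print(infer_type("3.5", "int32"))      # float32 (column widened)
print(infer_type("hello", "float32"))  # string (fallback)
```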
-def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
- """Generator that yields rows of CSV file(s) in order."""
- for fn in filenames:
- with file_io.FileIO(fn, "r") as f:
- rdr = csv.reader(
- f,
- delimiter=field_delim,
- quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
- if header:
- next(rdr) # Skip header lines
-
- for csv_row in rdr:
- if len(csv_row) != num_cols:
- raise ValueError(
- "Problem inferring types: CSV row has different number of fields "
- "than expected.")
- yield csv_row
-
-
-def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
- na_value, header, num_rows_for_inference,
- select_columns):
- """Infers column types from the first N valid CSV records of files."""
- if select_columns is None:
- select_columns = range(num_cols)
- inferred_types = [None] * len(select_columns)
-
- for i, csv_row in enumerate(
- _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
- if num_rows_for_inference is not None and i >= num_rows_for_inference:
- break
-
- for j, col_index in enumerate(select_columns):
- inferred_types[j] = _infer_type(csv_row[col_index], na_value,
- inferred_types[j])
-
- # Replace None's with a default type
- inferred_types = [t or dtypes.string for t in inferred_types]
- # Default to 0 or '' for null values
- return [
- constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
- for t in inferred_types
- ]
-
-
-def _infer_column_names(filenames, field_delim, use_quote_delim):
- """Infers column names from first rows of files."""
- csv_kwargs = {
- "delimiter": field_delim,
- "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
- }
- with file_io.FileIO(filenames[0], "r") as f:
- try:
- column_names = next(csv.reader(f, **csv_kwargs))
- except StopIteration:
- raise ValueError(("Received StopIteration when reading the header line "
- "of %s. Empty file?") % filenames[0])
-
- for name in filenames[1:]:
- with file_io.FileIO(name, "r") as f:
- try:
- if next(csv.reader(f, **csv_kwargs)) != column_names:
- raise ValueError(
- "Files have different column names in the header row.")
- except StopIteration:
- raise ValueError(("Received StopIteration when reading the header line "
- "of %s. Empty file?") % filenames[0])
- return column_names
-
-
-def _get_sorted_col_indices(select_columns, column_names):
- """Transforms select_columns argument into sorted column indices."""
- names_to_indices = {n: i for i, n in enumerate(column_names)}
- num_cols = len(column_names)
- for i, v in enumerate(select_columns):
- if isinstance(v, int):
- if v < 0 or v >= num_cols:
- raise ValueError(
- "Column index %d specified in select_columns out of valid range." %
- v)
- continue
- if v not in names_to_indices:
- raise ValueError(
- "Value '%s' specified in select_columns not a valid column index or "
- "name." % v)
- select_columns[i] = names_to_indices[v]
-
- # Sort and ensure there are no duplicates
- result = sorted(set(select_columns))
- if len(result) != len(select_columns):
- raise ValueError("select_columns contains duplicate columns")
- return result
-
-
-def _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
- """Optionally shuffle and repeat dataset, as requested."""
- if num_epochs != 1 and shuffle:
- # Use shuffle_and_repeat for perf
- return dataset.apply(
- shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
- shuffle_seed))
- elif shuffle:
- return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
- elif num_epochs != 1:
- return dataset.repeat(num_epochs)
- return dataset
-
-
-def make_tf_record_dataset(file_pattern,
- batch_size,
- parser_fn=None,
- num_epochs=None,
- shuffle=True,
- shuffle_buffer_size=None,
- shuffle_seed=None,
- prefetch_buffer_size=optimization.AUTOTUNE,
- num_parallel_reads=None,
- num_parallel_parser_calls=None,
- drop_final_batch=False):
- """Reads and optionally parses TFRecord files into a dataset.
-
- Provides common functionality such as batching, optional parsing, shuffling,
- and performant defaults.
-
- Args:
- file_pattern: List of files or patterns of TFRecord file paths.
- See `tf.gfile.Glob` for pattern rules.
- batch_size: An int representing the number of records to combine
- in a single batch.
- parser_fn: (Optional.) A function accepting string input to parse
- and process the record contents. This function must map records
- to components of a fixed shape, so they may be batched. By
- default, uses the record contents unmodified.
- num_epochs: (Optional.) An int specifying the number of times this
- dataset is repeated. If None (the default), cycles through the
- dataset forever.
- shuffle: (Optional.) A bool that indicates whether the input
- should be shuffled. Defaults to `True`.
- shuffle_buffer_size: (Optional.) Buffer size to use for
- shuffling. A large buffer size ensures better shuffling, but
- increases memory usage and startup time.
- shuffle_seed: (Optional.) Randomization seed to use for shuffling.
- prefetch_buffer_size: (Optional.) An int specifying the number of
- feature batches to prefetch for performance improvement.
- Defaults to auto-tune. Set to 0 to disable prefetching.
- num_parallel_reads: (Optional.) Number of threads used to read
- records from files. By default or if set to a value >1, the
- results will be interleaved.
- num_parallel_parser_calls: (Optional.) Number of records to parse
- in parallel. Defaults to an automatic selection.
- drop_final_batch: (Optional.) Whether the last batch should be
- dropped in case its size is smaller than `batch_size`; the
- default behavior is not to drop the smaller batch.
-
- Returns:
- A dataset, where each element matches the output of `parser_fn`
- except it will have an additional leading `batch-size` dimension,
- or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
- unspecified.
- """
- files = dataset_ops.Dataset.list_files(
- file_pattern, shuffle=shuffle, seed=shuffle_seed)
-
- if num_parallel_reads is None:
- # Note: We considered auto-tuning this value, but there is a concern
- # that this affects the mixing of records from different files, which
- # could affect training convergence/accuracy, so we are defaulting to
- # a constant for now.
- num_parallel_reads = 24
- dataset = core_readers.TFRecordDataset(
- files, num_parallel_reads=num_parallel_reads)
-
- if shuffle_buffer_size is None:
- # TODO(josh11b): Auto-tune this value when not specified
- shuffle_buffer_size = 10000
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- drop_final_batch = drop_final_batch or num_epochs is None
-
- if parser_fn is None:
- dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
- else:
- # TODO(josh11b): if num_parallel_parser_calls is None, use some function
- # of num cores instead of map_and_batch's default behavior of one batch.
- dataset = dataset.apply(batching.map_and_batch(
- parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
- drop_remainder=drop_final_batch))
-
- if prefetch_buffer_size == 0:
- return dataset
- else:
- return dataset.prefetch(buffer_size=prefetch_buffer_size)
-
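The removed `make_tf_record_dataset()` now lives under `tf.data.experimental` alongside the other readers. A minimal sketch of the equivalent call at the new location; the file pattern and `parser_fn` below are placeholders for illustration:

```python
import tensorflow as tf

def parser_fn(serialized):
  # Hypothetical per-record parser; maps each record to fixed-shape components.
  return tf.parse_single_example(
      serialized, {"label": tf.FixedLenFeature([], tf.int64)})

dataset = tf.data.experimental.make_tf_record_dataset(
    file_pattern="/tmp/train-*.tfrecord",
    batch_size=32,
    parser_fn=parser_fn,
    num_epochs=1,
    shuffle=True)
```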
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.make_csv_dataset(...)`.")
def make_csv_dataset(
file_pattern,
batch_size,
@@ -387,7 +112,6 @@ def make_csv_dataset(
prefetch_buffer_size: An int specifying the number of feature
batches to prefetch for performance improvement. Recommended value is the
number of batches consumed per training step. Defaults to auto-tune.
-
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
sloppy: If `True`, reading performance will be improved at
@@ -411,106 +135,18 @@ def make_csv_dataset(
Raises:
ValueError: If any of the arguments is malformed.
"""
- # Create dataset of all matching filenames
- filenames = _get_file_names(file_pattern, False)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
- if shuffle:
- dataset = dataset.shuffle(len(filenames), shuffle_seed)
-
- # Clean arguments; figure out column names and defaults
+ return readers.make_csv_dataset(
+ file_pattern, batch_size, column_names, column_defaults, label_name,
+ select_columns, field_delim, use_quote_delim, na_value, header,
+ num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
+ prefetch_buffer_size, num_parallel_reads, sloppy, num_rows_for_inference,
+ compression_type)
- if column_names is None:
- if not header:
- raise ValueError("Cannot infer column names without a header line.")
- # If column names are not provided, infer from the header lines
- column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
- if len(column_names) != len(set(column_names)):
- raise ValueError("Cannot have duplicate column names.")
- if select_columns is not None:
- select_columns = _get_sorted_col_indices(select_columns, column_names)
-
- if column_defaults is not None:
- column_defaults = [
- constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
- for x in column_defaults
- ]
- else:
- # If column defaults are not provided, infer from records at graph
- # construction time
- column_defaults = _infer_column_defaults(
- filenames, len(column_names), field_delim, use_quote_delim, na_value,
- header, num_rows_for_inference, select_columns)
-
- if select_columns is not None and len(column_defaults) != len(select_columns):
- raise ValueError(
- "If specified, column_defaults and select_columns must have same "
- "length."
- )
- if select_columns is not None and len(column_names) > len(select_columns):
- # Pick the relevant subset of column names
- column_names = [column_names[i] for i in select_columns]
-
- if label_name is not None and label_name not in column_names:
- raise ValueError("`label_name` provided must be one of the columns.")
-
- def filename_to_dataset(filename):
- return CsvDataset(
- filename,
- record_defaults=column_defaults,
- field_delim=field_delim,
- use_quote_delim=use_quote_delim,
- na_value=na_value,
- select_cols=select_columns,
- header=header,
- compression_type=compression_type,
- )
-
- def map_fn(*columns):
- """Organizes columns into a features dictionary.
-
- Args:
- *columns: list of `Tensor`s corresponding to one csv record.
- Returns:
- An OrderedDict of feature names to values for that particular record. If
- label_name is provided, extracts the label feature to be returned as the
- second element of the tuple.
- """
- features = collections.OrderedDict(zip(column_names, columns))
- if label_name is not None:
- label = features.pop(label_name)
- return features, label
- return features
-
- # Read files sequentially (if num_parallel_reads=1) or in parallel
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
-
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # Apply batch before map for perf, because map has high overhead relative
- # to the size of the computation in each map.
- # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- dataset = dataset.batch(batch_size=batch_size,
- drop_remainder=num_epochs is None)
- dataset = dataset_ops.MapDataset(
- dataset, map_fn, use_inter_op_parallelism=False)
- dataset = dataset.prefetch(prefetch_buffer_size)
-
- return dataset
-
-
-_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
-
-
-class CsvDataset(dataset_ops.DatasetSource):
+class CsvDataset(readers.CsvDataset):
"""A Dataset comprising lines from one or more CSV files."""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.CsvDataset(...)`.")
def __init__(self,
filenames,
record_defaults,
@@ -521,140 +157,13 @@ class CsvDataset(dataset_ops.DatasetSource):
use_quote_delim=True,
na_value="",
select_cols=None):
- """Creates a `CsvDataset` by reading and decoding CSV files.
-
- The elements of this dataset correspond to records from the file(s).
- RFC 4180 format is expected for CSV files
- (https://tools.ietf.org/html/rfc4180)
- Note that we allow leading and trailing spaces for int or float fields.
-
-
- For example, suppose we have a file 'my_file0.csv' with four CSV columns of
- different data types:
- ```
- abcdefg,4.28E10,5.55E6,12
- hijklmn,-5.3E14,,2
- ```
-
- We can construct a CsvDataset from it as follows:
- ```python
- dataset = tf.contrib.data.CsvDataset(
- "my_file*.csv",
- [tf.float32, # Required field, use dtype or empty tensor
- tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
- tf.int32, # Required field, use dtype or empty tensor
- ],
- select_cols=[1,2,3] # Only parse last three columns
- )
- ```
-
- The expected output of its iterations is:
- ```python
- next_element = dataset.make_one_shot_iterator().get_next()
- with tf.Session() as sess:
- while True:
- try:
- print(sess.run(next_element))
- except tf.errors.OutOfRangeError:
- break
-
- >> (4.28e10, 5.55e6, 12)
- >> (-5.3e14, 0.0, 2)
- ```
-
- Args:
- filenames: A `tf.string` tensor containing one or more filenames.
- record_defaults: A list of default values for the CSV fields. Each item in
- the list is either a valid CSV `DType` (float32, float64, int32, int64,
- string), or a `Tensor` object with one of the above types. One per
- column of CSV data, with either a scalar `Tensor` default value for the
- column if it is optional, or `DType` or empty `Tensor` if required. If
- both this and `select_columns` are specified, these must have the same
- lengths, and `column_defaults` is assumed to be sorted in order of
- increasing column index.
- compression_type: (Optional.) A `tf.string` scalar evaluating to one of
- `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
- compression.
- buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
- to buffer while reading files. Defaults to 4MB.
- header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
- have header line(s) that should be skipped when parsing. Defaults to
- `False`.
- field_delim: (Optional.) A `tf.string` scalar containing the delimiter
- character that separates fields in a record. Defaults to `","`.
- use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
- double quotation marks as regular characters inside of string fields
- (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
- na_value: (Optional.) A `tf.string` scalar indicating a value that will
- be treated as NA/NaN.
- select_cols: (Optional.) A sorted list of column indices to select from
- the input data. If specified, only this subset of columns will be
- parsed. Defaults to parsing all columns.
- """
- super(CsvDataset, self).__init__()
- self._filenames = ops.convert_to_tensor(
- filenames, dtype=dtypes.string, name="filenames")
- self._compression_type = convert.optional_param_to_tensor(
- "compression_type",
- compression_type,
- argument_default="",
- argument_dtype=dtypes.string)
- record_defaults = [
- constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
- for x in record_defaults
- ]
- self._record_defaults = ops.convert_n_to_tensor(
- record_defaults, name="record_defaults")
- self._buffer_size = convert.optional_param_to_tensor(
- "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
- self._header = ops.convert_to_tensor(
- header, dtype=dtypes.bool, name="header")
- self._field_delim = ops.convert_to_tensor(
- field_delim, dtype=dtypes.string, name="field_delim")
- self._use_quote_delim = ops.convert_to_tensor(
- use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
- self._na_value = ops.convert_to_tensor(
- na_value, dtype=dtypes.string, name="na_value")
- self._select_cols = convert.optional_param_to_tensor(
- "select_cols",
- select_cols,
- argument_default=[],
- argument_dtype=dtypes.int64,
- )
- self._output_shapes = tuple(
- tensor_shape.scalar() for _ in range(len(record_defaults)))
- self._output_types = tuple(d.dtype for d in self._record_defaults)
- self._output_classes = tuple(
- ops.Tensor for _ in range(len(record_defaults)))
-
- def _as_variant_tensor(self):
- # Constructs graph node for the dataset op.
- return gen_experimental_dataset_ops.experimental_csv_dataset(
- filenames=self._filenames,
- record_defaults=self._record_defaults,
- buffer_size=self._buffer_size,
- header=self._header,
- output_shapes=self._output_shapes,
- field_delim=self._field_delim,
- use_quote_delim=self._use_quote_delim,
- na_value=self._na_value,
- select_cols=self._select_cols,
- compression_type=self._compression_type,
- )
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_classes(self):
- return self._output_classes
+ super(CsvDataset, self).__init__(
+ filenames, record_defaults, compression_type, buffer_size, header,
+ field_delim, use_quote_delim, na_value, select_cols)
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_batched_features_dataset(...)`.")
def make_batched_features_dataset(file_pattern,
batch_size,
features,
@@ -759,57 +268,15 @@ def make_batched_features_dataset(file_pattern,
Raises:
ValueError: If `label_key` is not one of the `features` keys.
"""
- # Create dataset of all matching filenames
- filenames = _get_file_names(file_pattern, False)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
- if shuffle:
- dataset = dataset.shuffle(len(filenames), shuffle_seed)
-
- # Read `Example` records from files as tensor objects.
- if reader_args is None:
- reader_args = []
+ return readers.make_batched_features_dataset(
+ file_pattern, batch_size, features, reader, label_key, reader_args,
+ num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
+ prefetch_buffer_size, reader_num_threads, parser_num_threads,
+ sloppy_ordering, drop_final_batch)
- # Read files sequentially (if reader_num_threads=1) or in parallel
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- lambda filename: reader(filename, *reader_args),
- cycle_length=reader_num_threads,
- sloppy=sloppy_ordering))
- # Extract values if the `Example` tensors are stored as key-value tuples.
- if dataset.output_types == (dtypes.string, dtypes.string):
- dataset = dataset_ops.MapDataset(
- dataset, lambda _, v: v, use_inter_op_parallelism=False)
-
- # Apply dataset repeat and shuffle transformations.
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- dataset = dataset.batch(
- batch_size, drop_remainder=drop_final_batch or num_epochs is None)
-
- # Parse `Example` tensors to a dictionary of `Feature` tensors.
- dataset = dataset.apply(
- parsing_ops.parse_example_dataset(
- features, num_parallel_calls=parser_num_threads))
-
- if label_key:
- if label_key not in features:
- raise ValueError(
- "The `label_key` provided (%r) must be one of the `features` keys." %
- label_key)
- dataset = dataset.map(lambda x: (x, x.pop(label_key)))
-
- dataset = dataset.prefetch(prefetch_buffer_size)
- return dataset
-
-
-@deprecation.deprecated(None,
- "Use `tf.contrib.data.make_batched_features_dataset`")
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_batched_features_dataset(...)`")
def read_batch_features(file_pattern,
batch_size,
features,
@@ -879,7 +346,7 @@ def read_batch_features(file_pattern,
Returns:
A dict from keys in features to `Tensor` or `SparseTensor` objects.
"""
- dataset = make_batched_features_dataset(
+ dataset = readers.make_batched_features_dataset(
file_pattern,
batch_size,
features,
@@ -893,96 +360,13 @@ def read_batch_features(file_pattern,
return outputs
-def _get_file_names(file_pattern, shuffle):
- """Parse list of file names from pattern, optionally shuffled.
-
- Args:
- file_pattern: File glob pattern, or list of glob patterns.
- shuffle: Whether to shuffle the order of file names.
-
- Returns:
- List of file names matching `file_pattern`.
-
- Raises:
- ValueError: If `file_pattern` is empty, or pattern matches no files.
- """
- if isinstance(file_pattern, list):
- if not file_pattern:
- raise ValueError("File pattern is empty.")
- file_names = []
- for entry in file_pattern:
- file_names.extend(gfile.Glob(entry))
- else:
- file_names = list(gfile.Glob(file_pattern))
-
- if not file_names:
- raise ValueError("No files match %s." % file_pattern)
-
- # Sort files so it will be deterministic for unit tests.
- if not shuffle:
- file_names = sorted(file_names)
- return file_names
-
-
-class SqlDataset(dataset_ops.DatasetSource):
+class SqlDataset(readers.SqlDataset):
"""A `Dataset` consisting of the results from a SQL query."""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.SqlDataset(...)`.")
def __init__(self, driver_name, data_source_name, query, output_types):
- """Creates a `SqlDataset`.
-
- `SqlDataset` allows a user to read data from the result set of a SQL query.
- For example:
-
- ```python
- dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
- "SELECT name, age FROM people",
- (tf.string, tf.int32))
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
- # Prints the rows of the result set of the above query.
- while True:
- try:
- print(sess.run(next_element))
- except tf.errors.OutOfRangeError:
- break
- ```
-
- Args:
- driver_name: A 0-D `tf.string` tensor containing the database type.
- Currently, the only supported value is 'sqlite'.
- data_source_name: A 0-D `tf.string` tensor containing a connection string
- to connect to the database.
- query: A 0-D `tf.string` tensor containing the SQL query to execute.
- output_types: A tuple of `tf.DType` objects representing the types of the
- columns returned by `query`.
- """
- super(SqlDataset, self).__init__()
- self._driver_name = ops.convert_to_tensor(
- driver_name, dtype=dtypes.string, name="driver_name")
- self._data_source_name = ops.convert_to_tensor(
- data_source_name, dtype=dtypes.string, name="data_source_name")
- self._query = ops.convert_to_tensor(
- query, dtype=dtypes.string, name="query")
- self._output_types = output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.sql_dataset(self._driver_name,
- self._data_source_name, self._query,
- nest.flatten(self.output_types),
- nest.flatten(self.output_shapes))
-
- @property
- def output_classes(self):
- return nest.map_structure(lambda _: ops.Tensor, self._output_types)
-
- @property
- def output_shapes(self):
- return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
- self._output_types)
-
- @property
- def output_types(self):
- return self._output_types
+ super(SqlDataset, self).__init__(
+ driver_name, data_source_name, query, output_types)
class LMDBDataset(dataset_ops.DatasetSource):
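With the readers now forwarding to `tf.data.experimental`, the CSV example that used to appear in the `CsvDataset` docstring translates directly to the new endpoint. A sketch with an illustrative file name:

```python
import tensorflow as tf

# New call site; record_defaults mirrors the removed docstring example.
dataset = tf.data.experimental.CsvDataset(
    "my_file0.csv",
    record_defaults=[tf.float32,                            # required column
                     tf.constant([0.0], dtype=tf.float32),  # optional, defaults to 0.0
                     tf.int32],                             # required column
    select_cols=[1, 2, 3])  # parse only the last three columns
```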
diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py
index 75642f143e..29d77528d9 100644
--- a/tensorflow/contrib/data/python/ops/resampling.py
+++ b/tensorflow/contrib/data/python/ops/resampling.py
@@ -17,22 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.contrib.data.python.ops import scan_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.data.experimental.ops import resampling
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.rejection_resample(...)`.")
def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
"""A transformation that resamples a dataset to achieve a target distribution.
@@ -52,243 +42,5 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
- class_values_ds = dataset.map(class_func)
-
- # Get initial distribution.
- if initial_dist is not None:
- initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
- acceptance_dist, prob_of_original = (
- _calculate_acceptance_probs_with_mixing(initial_dist_t,
- target_dist_t))
- initial_dist_ds = dataset_ops.Dataset.from_tensors(
- initial_dist_t).repeat()
- acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
- acceptance_dist).repeat()
- prob_of_original_ds = dataset_ops.Dataset.from_tensors(
- prob_of_original).repeat()
- else:
- initial_dist_ds = _estimate_initial_dist_ds(
- target_dist_t, class_values_ds)
- acceptance_and_original_prob_ds = initial_dist_ds.map(
- lambda initial: _calculate_acceptance_probs_with_mixing(
- initial, target_dist_t))
- acceptance_dist_ds = acceptance_and_original_prob_ds.map(
- lambda accept_prob, _: accept_prob)
- prob_of_original_ds = acceptance_and_original_prob_ds.map(
- lambda _, prob_original: prob_original)
- filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
- class_values_ds, seed)
- # Prefetch filtered dataset for speed.
- filtered_ds = filtered_ds.prefetch(3)
-
- prob_original_static = _get_prob_original_static(
- initial_dist_t, target_dist_t) if initial_dist is not None else None
- if prob_original_static == 1:
- return dataset_ops.Dataset.zip((class_values_ds, dataset))
- elif prob_original_static == 0:
- return filtered_ds
- else:
- return interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
- weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
- seed=seed)
-
- return _apply_fn
-
-
-def _get_prob_original_static(initial_dist_t, target_dist_t):
- """Returns the static probability of sampling from the original.
-
- `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
- an Op that it isn't defined for. We have some custom logic to avoid this.
-
- Args:
- initial_dist_t: A tensor of the initial distribution.
- target_dist_t: A tensor of the target distribution.
-
- Returns:
- The probability of sampling from the original distribution as a constant,
- if it is a constant, or `None`.
- """
- init_static = tensor_util.constant_value(initial_dist_t)
- target_static = tensor_util.constant_value(target_dist_t)
-
- if init_static is None or target_static is None:
- return None
- else:
- return np.min(target_static / init_static)
-
-
-def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
- seed):
- """Filters a dataset based on per-class acceptance probabilities.
-
- Args:
- dataset: The dataset to be filtered.
- acceptance_dist_ds: A dataset of acceptance probabilities.
- initial_dist_ds: A dataset of the initial probability distribution, given or
- estimated.
- class_values_ds: A dataset of the corresponding classes.
- seed: (Optional.) Python integer seed for the resampler.
-
- Returns:
- A dataset of (class value, data) after filtering.
- """
- def maybe_warn_on_large_rejection(accept_dist, initial_dist):
- proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
- return control_flow_ops.cond(
- math_ops.less(proportion_rejected, .5),
- lambda: accept_dist,
- lambda: logging_ops.Print( # pylint: disable=g-long-lambda
- accept_dist, [proportion_rejected, initial_dist, accept_dist],
- message="Proportion of examples rejected by sampler is high: ",
- summarize=100,
- first_n=10))
-
- acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
- initial_dist_ds))
- .map(maybe_warn_on_large_rejection))
-
- def _gather_and_copy(class_val, acceptance_prob, data):
- return class_val, array_ops.gather(acceptance_prob, class_val), data
-
- current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
- (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
- filtered_ds = (
- current_probabilities_and_class_and_data_ds
- .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
- return filtered_ds.map(lambda class_value, _, data: (class_value, data))
-
-
-def _estimate_initial_dist_ds(
- target_dist_t, class_values_ds, dist_estimation_batch_size=32,
- smoothing_constant=10):
- num_classes = (target_dist_t.shape[0].value or
- array_ops.shape(target_dist_t)[0])
- initial_examples_per_class_seen = array_ops.fill(
- [num_classes], np.int64(smoothing_constant))
-
- def update_estimate_and_tile(num_examples_per_class_seen, c):
- updated_examples_per_class_seen, dist = _estimate_data_distribution(
- c, num_examples_per_class_seen)
- tiled_dist = array_ops.tile(
- array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
- return updated_examples_per_class_seen, tiled_dist
-
- initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
- .apply(scan_ops.scan(initial_examples_per_class_seen,
- update_estimate_and_tile))
- .apply(batching.unbatch()))
-
- return initial_dist_ds
-
-
-def _get_target_to_initial_ratio(initial_probs, target_probs):
- # Add tiny to initial_probs to avoid divide by zero.
- denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
- return target_probs / denom
-
-
-def _estimate_data_distribution(c, num_examples_per_class_seen):
- """Estimate data distribution as labels are seen.
-
- Args:
- c: The class labels. Type `int32`, shape `[batch_size]`.
- num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
- containing counts.
-
- Returns:
- num_examples_per_class_seen: Updated counts. Type `int64`, shape
- `[num_classes]`.
- dist: The updated distribution. Type `float32`, shape `[num_classes]`.
- """
- num_classes = num_examples_per_class_seen.get_shape()[0].value
- # Update the class-count based on what labels are seen in batch.
- num_examples_per_class_seen = math_ops.add(
- num_examples_per_class_seen, math_ops.reduce_sum(
- array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
- init_prob_estimate = math_ops.truediv(
- num_examples_per_class_seen,
- math_ops.reduce_sum(num_examples_per_class_seen))
- dist = math_ops.cast(init_prob_estimate, dtypes.float32)
- return num_examples_per_class_seen, dist
-
-
-def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
- """Calculates the acceptance probabilities and mixing ratio.
-
- In this case, we assume that we can *either* sample from the original data
- distribution with probability `m`, or sample from a reshaped distribution
- that comes from rejection sampling on the original distribution. This
- rejection sampling is done on a per-class basis, with `a_i` representing the
- probability of accepting data from class `i`.
-
- This method is based on solving the following analysis for the reshaped
- distribution:
-
- Let F be the probability of a rejection (on any example).
- Let p_i be the proportion of examples in the data in class i (init_probs)
- Let a_i be the rate at which the rejection sampler should *accept* class i
- Let t_i be the target proportion in the minibatches for class i (target_probs)
-
- ```
- F = sum_i(p_i * (1-a_i))
- = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
- ```
-
- An example with class `i` will be accepted if `k` rejections occur, then an
- example with class `i` is seen by the rejector, and it is accepted. This can
- be written as follows:
-
- ```
- t_i = sum_k=0^inf(F^k * p_i * a_i)
- = p_i * a_i / (1 - F) using geometric series identity, since 0 <= F < 1
- = p_i * a_i / sum_j(p_j * a_j) using F from above
- ```
-
- Note that the following constraints hold:
- ```
- 0 <= p_i <= 1, sum_i(p_i) = 1
- 0 <= a_i <= 1
- 0 <= t_i <= 1, sum_i(t_i) = 1
- ```
-
- A solution for a_i in terms of the other variables is the following:
- ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
-
- If we try to minimize the amount of data rejected, we get the following:
-
- M_max = max_i [ t_i / p_i ]
- M_min = min_i [ t_i / p_i ]
-
- The desired probability of accepting data if it comes from class `i`:
-
- a_i = (t_i/p_i - m) / (M_max - m)
-
- The desired probability of pulling a data element from the original dataset,
- rather than the filtered one:
-
- m = M_min
-
- Args:
- initial_probs: A Tensor of the initial probability distribution, given or
- estimated.
- target_probs: A Tensor of the target probability distribution over classes.
-
- Returns:
- (A 1D Tensor with the per-class acceptance probabilities, the desired
- probability of pull from the original distribution.)
- """
- ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
- max_ratio = math_ops.reduce_max(ratio_l)
- min_ratio = math_ops.reduce_min(ratio_l)
-
- # Target prob to sample from original distribution.
- m = min_ratio
-
- # TODO(joelshor): Simplify fraction, if possible.
- a_i = (ratio_l - m) / (max_ratio - m)
- return a_i, m
+ return resampling.rejection_resample(class_func, target_dist, initial_dist,
+ seed)
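A small worked example of the acceptance and mixing formulas from the docstring removed above (`a_i = (t_i/p_i - m) / (M_max - m)` with `m = M_min`), using made-up class distributions:

```python
import numpy as np

init_probs = np.array([0.7, 0.2, 0.1])    # p_i: observed class proportions
target_probs = np.array([1 / 3.0] * 3)    # t_i: desired class proportions

ratio = target_probs / (init_probs + np.finfo(init_probs.dtype).tiny)
m = ratio.min()                            # probability of keeping an original example
accept = (ratio - m) / (ratio.max() - m)   # a_i: per-class acceptance probability

print(m)       # ~0.48: roughly half the stream passes through the original branch
print(accept)  # [0., ~0.42, 1.]: the most over-represented class is rejected most often
```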
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index c52582cd35..0ca9fddb23 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -17,137 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ScanDataset(dataset_ops.UnaryDataset):
- """A dataset that scans a function across its input."""
-
- def __init__(self, input_dataset, initial_state, scan_func):
- """See `scan()` for details."""
- super(_ScanDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- with ops.name_scope("initial_state"):
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- self._initial_state = nest.pack_sequence_as(initial_state, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
- t, name="component_%d" % i)
- for i, t in enumerate(nest.flatten(initial_state))
- ])
-
- # Compute initial values for the state classes, shapes and types based on
- # the initial state. The shapes may be refined by running `tf_scan_func` one
- # or more times below.
- self._state_classes = sparse.get_classes(self._initial_state)
- self._state_shapes = nest.pack_sequence_as(
- self._initial_state,
- [t.get_shape() for t in nest.flatten(self._initial_state)])
- self._state_types = nest.pack_sequence_as(
- self._initial_state,
- [t.dtype for t in nest.flatten(self._initial_state)])
-
- # Will be populated by calling `tf_scan_func`.
- self._output_classes = None
- self._output_shapes = None
- self._output_types = None
-
- # Iteratively rerun the scan function until reaching a fixed point on
- # `self._state_shapes`.
- need_to_rerun = True
- while need_to_rerun:
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- scan_func, "tf.contrib.data.scan()",
- input_classes=(self._state_classes, input_dataset.output_classes),
- input_shapes=(self._state_shapes, input_dataset.output_shapes),
- input_types=(self._state_types, input_dataset.output_types),
- add_to_graph=False)
- if not (
- isinstance(wrapped_func.output_types, collections.Sequence) and
- len(wrapped_func.output_types) == 2):
- raise TypeError("The scan function must return a pair comprising the "
- "new state and the output value.")
-
- new_state_classes, self._output_classes = wrapped_func.output_classes
-
- # Extract and validate class information from the returned values.
- for new_state_class, state_class in zip(
- nest.flatten(new_state_classes),
- nest.flatten(self._state_classes)):
- if not issubclass(new_state_class, state_class):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes, new_state_classes))
-
- # Extract and validate type information from the returned values.
- new_state_types, self._output_types = wrapped_func.output_types
- for new_state_type, state_type in zip(
- nest.flatten(new_state_types), nest.flatten(self._state_types)):
- if new_state_type != state_type:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types, new_state_types))
-
- # Extract shape information from the returned values.
- new_state_shapes, self._output_shapes = wrapped_func.output_shapes
-
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_new_state_shapes = nest.flatten(new_state_shapes)
- weakened_state_shapes = [
- original.most_specific_compatible_shape(new)
- for original, new in zip(flat_state_shapes, flat_new_state_shapes)
- ]
-
- need_to_rerun = False
- for original_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if original_shape.ndims is not None and (
- weakened_shape.ndims is None or
- original_shape.as_list() != weakened_shape.as_list()):
- need_to_rerun = True
- break
-
- if need_to_rerun:
- self._state_shapes = nest.pack_sequence_as(self._state_shapes,
- weakened_state_shapes)
-
- self._scan_func = wrapped_func.function
- self._scan_func.add_to_graph(ops.get_default_graph())
-
- def _as_variant_tensor(self):
- input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
- return gen_dataset_ops.scan_dataset(
- input_t,
- nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
- self._scan_func.captured_inputs,
- f=self._scan_func,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.scan(...)`.")
def scan(initial_state, scan_func):
"""A transformation that scans a function across an input dataset.
@@ -168,7 +42,4 @@ def scan(initial_state, scan_func):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return _ScanDataset(dataset, initial_state, scan_func)
-
- return _apply_fn
+ return scan_ops.scan(initial_state, scan_func)
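The relocated `scan` transformation threads a state through the dataset and emits one output per element. A running-sum sketch:

```python
import tensorflow as tf

# The scan function returns a (new_state, output) pair for each element.
dataset = tf.data.Dataset.range(5).apply(
    tf.data.experimental.scan(
        initial_state=tf.constant(0, dtype=tf.int64),
        scan_func=lambda state, x: (state + x, state + x)))
# Yields 0, 1, 3, 6, 10.
```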
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 985d1d87d0..329b34fdfe 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -17,54 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that fuses `shuffle` and `repeat`."""
-
- def __init__(self, input_dataset, buffer_size, count=None, seed=None):
- super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._buffer_size = ops.convert_to_tensor(
- buffer_size, dtype=dtypes.int64, name="buffer_size")
- if count is None:
- self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
- else:
- self._count = ops.convert_to_tensor(
- count, dtype=dtypes.int64, name="count")
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.shuffle_and_repeat_dataset(
- input_resource,
- buffer_size=self._buffer_size,
- count=self._count,
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.shuffle_and_repeat(...)`.")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
"""Shuffles and repeats a Dataset returning a new permutation for each epoch.
@@ -93,8 +51,4 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset): # pylint: disable=missing-docstring
- return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
-
- return _apply_fn
+ return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed)
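A usage sketch for the relocated `shuffle_and_repeat`, which fuses the two transformations so the shuffle buffer is not drained at epoch boundaries; this is the "perf" fusion the removed `_maybe_shuffle_and_repeat` helper in readers.py relied on:

```python
import tensorflow as tf

# Draws a fresh permutation for each of the three epochs.
dataset = tf.data.Dataset.range(100).apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=100, count=3, seed=7))
```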
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index f73c3fd9cb..20cceb4647 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -17,88 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import threading
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.ops import resource_variable_ops
-
-_uid_counter = 0
-_uid_lock = threading.Lock()
-
-
-def _generate_shared_name(prefix):
- with _uid_lock:
- global _uid_counter
- uid = _uid_counter
- _uid_counter += 1
- return "{}{}".format(prefix, uid)
-
-
-# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-class PrivateThreadPool(object):
- """A stateful resource that represents a private thread pool."""
-
- def __init__(self, num_threads, display_name=None,
- max_intra_op_parallelism=1):
- """Creates a `PrivateThreadPool` with the given number of threads."""
- if context.executing_eagerly():
- shared_name = _generate_shared_name("privatethreadpool")
- self._resource = ged_ops.experimental_thread_pool_handle(
- num_threads=num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name=display_name,
- shared_name=shared_name)
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device=context.context().device_name)
- else:
- self._resource = ged_ops.experimental_thread_pool_handle(
- num_threads=num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name=display_name)
-
-
-class _ThreadPoolDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and sets a custom threadpool."""
-
- def __init__(self, input_dataset, thread_pool):
- super(_ThreadPoolDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._thread_pool = thread_pool
-
- def _as_variant_tensor(self):
- return ged_ops.experimental_thread_pool_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._thread_pool._resource, # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def override_threadpool(dataset, thread_pool):
- """Returns a new dataset that uses the given thread pool for its operations.
-
- Args:
- dataset: A `tf.data.Dataset` object.
- thread_pool: A `PrivateThreadPool` object.
-
- Returns:
- A dataset containing the same values as `dataset`, but which uses
- `thread_pool` to compute any of its parallel operations (such as
- `tf.data.Dataset.map`).
- """
- return _ThreadPoolDataset(dataset, thread_pool)
+# pylint: disable=unused-import
+from tensorflow.python.data.experimental.ops.threadpool import override_threadpool
+from tensorflow.python.data.experimental.ops.threadpool import PrivateThreadPool
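`override_threadpool` and `PrivateThreadPool` keep their package-private import path, as the removed TODO notes. A sketch of routing a dataset's parallel work onto a dedicated pool:

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import threadpool

# Run the map's parallel calls on a private two-thread pool instead of the
# default inter-op pool.
dataset = tf.data.Dataset.range(1000).map(lambda x: x * 2, num_parallel_calls=4)
dataset = threadpool.override_threadpool(
    dataset,
    threadpool.PrivateThreadPool(2, display_name="private_pool"))
```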
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index ed363a7090..909d06c677 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -17,11 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.data.experimental.ops import unique as experimental_unique
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.unique()`.")
def unique():
"""Creates a `Dataset` from another `Dataset`, discarding duplicates.
@@ -39,39 +39,4 @@ def unique():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _UniqueDataset(dataset)
-
- return _apply_fn
-
-
-class _UniqueDataset(dataset_ops.UnaryDataset):
- """A `Dataset` contains the unique elements from its input."""
-
- def __init__(self, input_dataset):
- """See `unique()` for details."""
- super(_UniqueDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
- dtypes.string):
- raise TypeError(
- "`tf.contrib.data.unique()` only supports inputs with a single "
- "`tf.int32`, `tf.int64`, or `tf.string` component.")
-
- def _as_variant_tensor(self):
- return gen_experimental_dataset_ops.experimental_unique_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return experimental_unique.unique()
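A usage sketch for the relocated `unique()`; as the removed type check indicates, it only supports datasets with a single `tf.int32`, `tf.int64`, or `tf.string` component:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 1, 3, 2])  # int32 scalars
dataset = dataset.apply(tf.data.experimental.unique())          # yields 1, 2, 3
```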
diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py
index c455fdcba6..42fb69bf07 100644
--- a/tensorflow/contrib/data/python/ops/writers.py
+++ b/tensorflow/contrib/data/python/ops/writers.py
@@ -17,42 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.data.experimental.ops import writers
+from tensorflow.python.util import deprecation
-class TFRecordWriter(object):
+class TFRecordWriter(writers.TFRecordWriter):
"""Writes data to a TFRecord file."""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.TFRecordWriter(...)`.")
def __init__(self, filename, compression_type=None):
- self._filename = ops.convert_to_tensor(
- filename, dtypes.string, name="filename")
- self._compression_type = convert.optional_param_to_tensor(
- "compression_type",
- compression_type,
- argument_default="",
- argument_dtype=dtypes.string)
-
- def write(self, dataset):
- """Returns a `tf.Operation` to write a dataset to a file.
-
- Args:
- dataset: a `tf.data.Dataset` whose elements are to be written to a file
-
- Returns:
- A `tf.Operation` that, when run, writes contents of `dataset` to a file.
- """
- if not isinstance(dataset, dataset_ops.Dataset):
- raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- if (dataset.output_types != dtypes.string or
- dataset.output_shapes != tensor_shape.scalar()):
- raise TypeError(
- "`dataset` must produce scalar `DT_STRING` tensors whereas it "
- "produces shape {0} and types {1}".format(dataset.output_shapes,
- dataset.output_types))
- return gen_dataset_ops.dataset_to_tf_record(
- dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
+ super(TFRecordWriter, self).__init__(filename, compression_type)
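A graph-mode sketch for the relocated `TFRecordWriter`, which still expects a dataset of scalar `tf.string` records; the output path is illustrative:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([b"record-0", b"record-1"])
writer = tf.data.experimental.TFRecordWriter("/tmp/out.tfrecord")
write_op = writer.write(dataset)  # a tf.Operation that writes the dataset

with tf.Session() as sess:
  sess.run(write_op)
```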
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
index 8d949943b7..d48aa9c89b 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import warnings
-from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest as data_nest
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 135095a979..3aed121233 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
@@ -54,7 +54,7 @@ class Iterator(iterator_ops.EagerIterator):
"""
if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access
raise TypeError(
- "`tf.contrib.data.prefetch_to_device()` is not compatible with "
+ "`tf.data.experimental.prefetch_to_device()` is not compatible with "
"`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate "
"over the dataset instead.")
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index a753d77580..6a508fc6ba 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -24,11 +24,11 @@ import time
import numpy as np
from tensorflow.contrib import lookup
-from tensorflow.contrib.data.python.ops import prefetching_ops
-from tensorflow.contrib.data.python.ops import threadpool
-from tensorflow.contrib.data.python.ops import unique
from tensorflow.contrib.eager.python import datasets
from tensorflow.python.data import Dataset
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.experimental.ops import threadpool
+from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
index 34a9984b0e..d85188de03 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
@@ -169,11 +169,11 @@ class ImageNetInput(object):
# Read the data from disk in parallel
dataset = dataset.apply(
- tf.contrib.data.parallel_interleave(
+ tf.data.experimental.parallel_interleave(
fetch_dataset, cycle_length=self.num_parallel_calls, sloppy=True))
if self.cache:
dataset = dataset.cache().apply(
- tf.contrib.data.shuffle_and_repeat(1024 * 16))
+ tf.data.experimental.shuffle_and_repeat(1024 * 16))
else:
dataset = dataset.shuffle(1024)
@@ -188,9 +188,11 @@ class ImageNetInput(object):
# batch size. As long as this validation is done with consistent batch size,
# exactly the same images will be used.
dataset = dataset.apply(
- tf.contrib.data.map_and_batch(
- self.dataset_parser, batch_size=batch_size,
- num_parallel_batches=self.num_cores, drop_remainder=True))
+ tf.data.experimental.map_and_batch(
+ self.dataset_parser,
+ batch_size=batch_size,
+ num_parallel_batches=self.num_cores,
+ drop_remainder=True))
# Transpose for performance on TPU
if self.transpose_input:
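The hunk above is a representative migration: only the endpoint prefix changes, the arguments stay the same. A rough before/after sketch, where the parser, file pattern, and numeric arguments are placeholders rather than values taken from this file:

    import tensorflow as tf

    def parser(record):
      # Placeholder feature spec, for illustration only.
      return tf.parse_single_example(
          record, {"label": tf.FixedLenFeature([], tf.int64)})

    files = tf.data.Dataset.list_files("/tmp/train-*.tfrecord")

    # Previously: files.apply(tf.contrib.data.parallel_interleave(...)), etc.
    dataset = files.apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=4, sloppy=True))
    dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(1024))
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(
            parser, batch_size=32, num_parallel_batches=2, drop_remainder=True))
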
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 1aebed348d..89506ee661 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -25,12 +25,12 @@ import tempfile
import numpy as np
import six
-from tensorflow.contrib.data.python.ops import readers
from tensorflow.contrib.estimator.python.estimator import head as head_lib
from tensorflow.contrib.estimator.python.estimator import rnn
from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import parsing_utils
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 89b538d1ba..9e9345e875 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -23,8 +23,8 @@ import numpy as np
import six
from tensorflow.contrib import lookup
-from tensorflow.contrib.data.python.ops import counter
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD
index dcbef2881d..a217397c1a 100644
--- a/tensorflow/contrib/stateless/BUILD
+++ b/tensorflow/contrib/stateless/BUILD
@@ -9,19 +9,13 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
-tf_gen_op_wrapper_py(
- name = "stateless_random_ops",
- out = "gen_stateless_random_ops.py", # cmake chokes without this
- deps = ["//tensorflow/core:stateless_random_ops_op_lib"],
-)
-
py_library(
name = "stateless",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
- ":stateless_random_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:stateless_random_ops_gen",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py
index 0cca40f071..fe23fe0dd8 100644
--- a/tensorflow/contrib/stateless/__init__.py
+++ b/tensorflow/contrib/stateless/__init__.py
@@ -32,10 +32,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
+
# pylint: disable=wildcard-import
-from tensorflow.contrib.stateless.gen_stateless_random_ops import *
+from tensorflow.python.ops.gen_stateless_random_ops import *
-from tensorflow.python.framework import ops
from tensorflow.python.util.all_util import remove_undocumented
ops.NotDifferentiable("StatelessMultinomial")
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py
index d879170b68..c694e9c1bc 100644
--- a/tensorflow/contrib/tpu/python/tpu/datasets.py
+++ b/tensorflow/contrib/tpu/python/tpu/datasets.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md
index 639e708169..b6514e19dc 100644
--- a/tensorflow/contrib/tpu/tpu_estimator.md
+++ b/tensorflow/contrib/tpu/tpu_estimator.md
@@ -87,7 +87,7 @@ handle training:
label = tf.cast(features["label"], tf.int32)
return image, label
- dataset = tf.contrib.data.TFRecordDataset(
+ dataset = tf.data.TFRecordDataset(
filename, buffer_size=FLAGS.dataset_reader_buffer_size)
dataset = dataset.map(parser).cache().repeat().batch(batch_size)
images, labels = dataset.make_one_shot_iterator().get_next()
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index b565ebd073..00295f57f6 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -295,7 +295,6 @@ py_test(
tags = ["notsan"],
deps = [
":training_py",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
@@ -305,6 +304,7 @@ py_test(
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/data",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
index d9b0511a98..c1657fec7b 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt
new file mode 100644
index 0000000000..d3c70190dd
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessMultinomial"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt
new file mode 100644
index 0000000000..e294325fb8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessRandomNormal"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt
new file mode 100644
index 0000000000..95d414c54a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessRandomUniform"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt
new file mode 100644
index 0000000000..c72bdda94a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessTruncatedNormal"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/examples/get_started/regression/test.py b/tensorflow/examples/get_started/regression/test.py
index 0b1477ad96..bb4db6700b 100644
--- a/tensorflow/examples/get_started/regression/test.py
+++ b/tensorflow/examples/get_started/regression/test.py
@@ -29,7 +29,7 @@ import tensorflow.examples.get_started.regression.imports85 as imports85
sys.modules["imports85"] = imports85
# pylint: disable=g-bad-import-order,g-import-not-at-top
-import tensorflow.contrib.data as data
+import tensorflow.data as data
import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression
import tensorflow.examples.get_started.regression.linear_regression as linear_regression
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9275ad767e..fe81254ef7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1740,6 +1740,14 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "stateless_random_ops_gen",
+ visibility = [
+ "//tensorflow/contrib/stateless:__pkg__",
+ "//tensorflow/python/data/experimental/ops:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "list_ops_gen",
)
@@ -3302,9 +3310,11 @@ py_library(
"training/checkpointable/**/*.py",
# The following targets have their own build rules (same name as the
# file):
+ "training/basic_session_run_hooks.py",
"training/checkpoint_management.py",
"training/saveable_object.py",
"training/saver.py",
+ "training/session_run_hook.py",
"training/training_util.py",
],
),
@@ -3312,6 +3322,7 @@ py_library(
deps = [
":array_ops",
":array_ops_gen",
+ ":basic_session_run_hooks",
":checkpoint_management",
":checkpoint_ops_gen",
":client",
@@ -3336,6 +3347,7 @@ py_library(
":saver",
":sdca_ops",
":session",
+ ":session_run_hook",
":sparse_ops",
":sparse_tensor",
":state_ops",
@@ -3380,6 +3392,28 @@ py_library(
)
py_library(
+ name = "session_run_hook",
+ srcs = ["training/session_run_hook.py"],
+ srcs_version = "PY2AND3",
+ deps = [":util"],
+)
+
+py_library(
+ name = "basic_session_run_hooks",
+ srcs = ["training/basic_session_run_hooks.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client",
+ ":framework",
+ ":platform",
+ ":protos_all_py",
+ ":session_run_hook",
+ ":training_util",
+ ":util",
+ ],
+)
+
+py_library(
name = "saver",
srcs = ["training/saver.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index 138141f4fc..e32eeecbb8 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -10,6 +10,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
+ "//tensorflow/python/data/experimental",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py
index f8b561205e..7536ba668a 100644
--- a/tensorflow/python/data/__init__.py
+++ b/tensorflow/python/data/__init__.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
+from tensorflow.python.data import experimental
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.ops.iterator_ops import Iterator
from tensorflow.python.data.ops.readers import FixedLengthRecordDataset
diff --git a/tensorflow/python/data/experimental/BUILD b/tensorflow/python/data/experimental/BUILD
new file mode 100644
index 0000000000..84e761d376
--- /dev/null
+++ b/tensorflow/python/data/experimental/BUILD
@@ -0,0 +1,16 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "experimental",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py
new file mode 100644
index 0000000000..2ac159d38a
--- /dev/null
+++ b/tensorflow/python/data/experimental/__init__.py
@@ -0,0 +1,109 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for building input pipelines.
+
+This module contains experimental `Dataset` sources and transformations that can
+be used in conjunction with the `tf.data.Dataset` API. Note that the
+`tf.data.experimental` API is not subject to the same backwards compatibility
+guarantees as `tf.data`, but we will provide deprecation advice in advance of
+removing existing functionality.
+
+See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
+
+@@Counter
+@@CheckpointInputPipelineHook
+@@CsvDataset
+@@Optional
+@@RandomDataset
+@@Reducer
+@@SqlDataset
+@@TFRecordWriter
+
+@@bucket_by_sequence_length
+@@choose_from_datasets
+@@copy_to_device
+@@dense_to_sparse_batch
+@@enumerate_dataset
+@@get_next_as_optional
+@@get_single_element
+@@group_by_reducer
+@@group_by_window
+@@ignore_errors
+@@latency_stats
+@@make_batched_features_dataset
+@@make_csv_dataset
+@@make_saveable_from_iterator
+@@map_and_batch
+@@parallel_interleave
+@@parse_example_dataset
+@@prefetch_to_device
+@@rejection_resample
+@@sample_from_datasets
+@@scan
+@@set_stats_aggregator
+@@shuffle_and_repeat
+@@StatsAggregator
+@@unbatch
+@@unique
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+
+from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch
+from tensorflow.python.data.experimental.ops.batching import map_and_batch
+from tensorflow.python.data.experimental.ops.batching import unbatch
+from tensorflow.python.data.experimental.ops.counter import Counter
+from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
+from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
+from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
+from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length
+from tensorflow.python.data.experimental.ops.grouping import group_by_reducer
+from tensorflow.python.data.experimental.ops.grouping import group_by_window
+from tensorflow.python.data.experimental.ops.grouping import Reducer
+from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets
+from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave
+from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
+from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
+from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+
+from tensorflow.python.data.experimental.ops.parsing_ops import parse_example_dataset
+from tensorflow.python.data.experimental.ops.prefetching_ops import copy_to_device
+from tensorflow.python.data.experimental.ops.prefetching_ops import prefetch_to_device
+from tensorflow.python.data.experimental.ops.random_ops import RandomDataset
+from tensorflow.python.data.experimental.ops.readers import CsvDataset
+from tensorflow.python.data.experimental.ops.readers import make_batched_features_dataset
+from tensorflow.python.data.experimental.ops.readers import make_csv_dataset
+from tensorflow.python.data.experimental.ops.readers import SqlDataset
+from tensorflow.python.data.experimental.ops.resampling import rejection_resample
+from tensorflow.python.data.experimental.ops.scan_ops import scan
+from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat
+from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
+from tensorflow.python.data.experimental.ops.stats_ops import set_stats_aggregator
+from tensorflow.python.data.experimental.ops.stats_ops import StatsAggregator
+from tensorflow.python.data.experimental.ops.unique import unique
+from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
+from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
+from tensorflow.python.data.ops.optional_ops import Optional
+# pylint: enable=unused-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
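A minimal usage sketch of the endpoint this file exports: experimental transformations are applied to a `tf.data.Dataset` through `Dataset.apply()`, exactly as the `tf.contrib.data` versions were (the input values and buffer size below are arbitrary):

    import tensorflow as tf

    values = tf.constant([1, 1, 2, 3, 3, 3], dtype=tf.int64)
    dataset = tf.data.Dataset.from_tensor_slices(values)
    dataset = dataset.apply(tf.data.experimental.unique())
    dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=8))
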
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
new file mode 100644
index 0000000000..a46c30ed2e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -0,0 +1,569 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "batch_dataset_op_test",
+ size = "medium",
+ srcs = ["batch_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss", # (b/79552534)
+ "no_pip",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "bucketing_test",
+ size = "medium",
+ srcs = ["bucketing_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "csv_dataset_op_test",
+ size = "medium",
+ srcs = ["csv_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/eager:context",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "dataset_constructor_op_test",
+ size = "medium",
+ srcs = ["dataset_constructor_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "manual",
+ "nomac", # b/62040583
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_test(
+ name = "directed_interleave_dataset_test",
+ size = "medium",
+ srcs = ["directed_interleave_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "get_single_element_test",
+ size = "small",
+ srcs = ["get_single_element_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:get_single_element",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "indexed_dataset_ops_test",
+ srcs = ["indexed_dataset_ops_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/experimental/ops:indexed_dataset_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "interleave_dataset_op_test",
+ size = "medium",
+ srcs = ["interleave_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "notap",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
+ name = "map_dataset_op_test",
+ size = "medium",
+ srcs = ["map_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "noasan", # times out
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_op_test",
+ size = "medium",
+ srcs = ["filter_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "map_defun_op_test",
+ size = "small",
+ srcs = ["map_defun_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:map_defun",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "parsing_ops_test",
+ size = "small",
+ srcs = ["parsing_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:parsing_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetching_ops_test",
+ size = "small",
+ srcs = ["prefetching_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+py_test(
+ name = "range_dataset_op_test",
+ size = "small",
+ srcs = ["range_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:counter",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "reader_dataset_ops_test_base",
+ testonly = 1,
+ srcs = [
+ "reader_dataset_ops_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/python/data/experimental/kernel_tests:__pkg__",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "reader_dataset_ops_test",
+ size = "medium",
+ srcs = ["reader_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "resample_test",
+ size = "medium",
+ srcs = ["resample_test.py"],
+ shard_count = 2,
+ srcs_version = "PY2AND3",
+ tags = [
+ "noasan",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:resampling",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "scan_dataset_op_test",
+ size = "small",
+ srcs = ["scan_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "shuffle_dataset_op_test",
+ size = "medium",
+ srcs = ["shuffle_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "sql_dataset_op_test_base",
+ srcs = ["sql_dataset_op_test_base.py"],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/python/data/experimental/kernel_tests:__pkg__",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "@org_sqlite//:python",
+ ],
+)
+
+py_test(
+ name = "sql_dataset_op_test",
+ size = "small",
+ srcs = ["sql_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":sql_dataset_op_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ ],
+)
+
+py_test(
+ name = "stats_dataset_ops_test",
+ size = "medium",
+ srcs = ["stats_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ ":stats_dataset_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "stats_dataset_test_base",
+ srcs = ["stats_dataset_test_base.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "threadpool_dataset_ops_test",
+ size = "small",
+ srcs = ["threadpool_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python/data/experimental/ops:threadpool",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "unique_dataset_op_test",
+ size = "small",
+ srcs = ["unique_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "writer_ops_test",
+ size = "small",
+ srcs = ["writer_ops_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:writers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
index fed7de5f2b..8703b2810e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
@@ -23,8 +23,8 @@ import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -32,7 +32,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
@@ -43,7 +42,6 @@ from tensorflow.python.util import compat
class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
iterator = (
@@ -302,128 +300,6 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(next_element)
- def testBatchAndDropRemainder(self):
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size))
- .make_initializable_iterator())
-
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_batch_size in [1, 3, 7, 10]:
- sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
- num_batches = 7 // test_batch_size
- for i in range(num_batches):
- result = sess.run(next_element)
- for component, result_component in zip(components, result):
- for j in range(test_batch_size):
- self.assertAllEqual(component[(i * test_batch_size + j)],
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testBatchAndDropRemainderSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(12).map(_sparse).apply(
- batching.batch_and_drop_remainder(5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(2):
- actual = sess.run(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testPaddedBatchAndDropRemainder(self):
- els = []
- for length in [3, 6, 9, 4, 12, 10, 2]:
- els.append((np.array(length), np.arange(length) + 1,
- np.array(length * 2)))
-
- dataset = dataset_ops.Dataset.from_tensors(els[0])
- for el in els[1:]:
- dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el))
-
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(
- batching.padded_batch_and_drop_remainder(
- batch_size, ([], [None], []))).make_initializable_iterator())
-
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_batch_size in [1, 3, 7, 10]:
- sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
- num_batches = 7 // test_batch_size
- for i in range(num_batches):
- result = sess.run(next_element)
- for component_idx, result_component in enumerate(result):
- for j in range(test_batch_size):
- data_idx = i * test_batch_size + j
- comp = result_component[j]
- unpadded = comp[comp > 0]
- if np.isscalar(comp):
- # The boolean mask indexing above adds a dim back. Rm it.
- unpadded = unpadded[0]
- self.assertAllEqual(els[data_idx][component_idx], unpadded)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPaddedBatchAndDropRemainderSparseError(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
-
- with self.assertRaises(TypeError):
- _ = dataset_ops.Dataset.range(10).map(_map_fn).apply(
- batching.padded_batch_and_drop_remainder(5))
-
- def testBatchAndDropRemainderShapeInference(self):
- components = (array_ops.placeholder(dtypes.int32),
- (array_ops.placeholder(dtypes.int32, shape=[None]),
- array_ops.placeholder(dtypes.int32, shape=[20, 30])))
-
- # Test with a statically known batch size.
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(128)))
-
- self.assertIs(None, dataset.output_shapes[0].ndims)
- self.assertEqual([128], dataset.output_shapes[1][0].as_list())
- self.assertEqual([128, 30], dataset.output_shapes[1][1].as_list())
-
- # Test with a dynamic batch size: the static shape will be unknown, because
- # `batch_size` is a placeholder.
- batch_size = array_ops.placeholder(dtypes.int64)
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size)))
-
- self.assertIs(None, dataset.output_shapes[0].ndims)
- self.assertEqual([None], dataset.output_shapes[1][0].as_list())
- self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
-
@parameterized.named_parameters(
("Default", None, None),
("SequentialCalls", 1, None),
@@ -720,197 +596,6 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-class RestructuredDatasetTest(test_base.DatasetTestBase):
-
- def test_assert_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(5).map(create_dataset)
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- self.assertEqual(expected_shapes, dataset.output_shapes)
-
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(3).map(create_dataset)
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
- with self.assertRaises(ValueError):
- dataset.apply(batching.assert_element_shape(wrong_shapes))
-
- def test_assert_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
- iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def test_assert_partial_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(5).map(create_dataset)
- partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape
- tensor_shape.TensorShape((None, 4))) # Partial shape
- result = dataset.apply(
- batching.assert_element_shape(partial_expected_shape))
- # Partial shapes are merged with actual shapes:
- actual_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- self.assertEqual(actual_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_partial_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(3).map(create_dataset)
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 10)))
- with self.assertRaises(ValueError):
- dataset.apply(batching.assert_element_shape(wrong_shapes))
-
- def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 4)))
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 10)))
- iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
-
class UnbatchDatasetBenchmark(test.Benchmark):
def benchmarkNativeUnbatch(self):
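The `batch_and_drop_remainder()` / `padded_batch_and_drop_remainder()` tests removed above covered transformations that remain deprecated in `tf.contrib.data` rather than moving, since core `tf.data` already provides an equivalent. A hedged sketch of that core spelling (batch sizes are arbitrary):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(7)

    # Formerly: dataset.apply(tf.contrib.data.batch_and_drop_remainder(3))
    dataset = dataset.batch(3, drop_remainder=True)   # yields [0,1,2], [3,4,5]

    # padded_batch() accepts the same flag for the padded variant.
    padded = tf.data.Dataset.range(7).padded_batch(
        3, padded_shapes=[], drop_remainder=True)
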
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
index ae401f786c..153a03989b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
@@ -21,7 +21,7 @@ import random
import numpy as np
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py
index 5b3c512b64..4ee1779710 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py
@@ -27,9 +27,9 @@ import zlib
import numpy as np
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py
index 722e87e555..3fc7157bc5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
index 595cecef4d..7f435b8239 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
@@ -22,7 +22,7 @@ import os
import numpy as np
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
index bc10c21472..796a692c56 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
@@ -84,7 +84,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
# Use chi-squared test to assert that the observed distribution matches the
# expected distribution. Based on the implementation in
- # "tensorflow/python/kernel_tests/multinomial_op_test.py".
+ # "third_party/tensorflow/python/kernel_tests/multinomial_op_test.py".
for probs in [[.85, .05, .1], rand_probs, [1.]]:
probs = np.asarray(probs)
classes = len(probs)
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
index 6d01bf585c..c6ee88c676 100644
--- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
@@ -21,8 +21,8 @@ import time
import numpy as np
-from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
index cc22ea1df7..8c07afbac5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
@@ -18,10 +18,8 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
-import numpy as np
-from tensorflow.contrib.data.python.ops import get_single_element
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
@@ -69,32 +67,6 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaisesRegexp(error, error_msg):
sess.run(element, feed_dict={skip_t: skip, take_t: take})
- @parameterized.named_parameters(
- ("SumZero", 0),
- ("SumOne", 1),
- ("SumFive", 5),
- ("SumTen", 10),
- )
- def testReduceDataset(self, stop):
- def init_fn(_):
- return np.int64(0)
-
- def reduce_fn(state, value):
- return state + value
-
- def finalize_fn(state):
- return state
-
- sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
-
- stop_t = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset_ops.Dataset.range(stop_t)
- element = get_single_element.reduce_dataset(dataset, sum_reducer)
-
- with self.cached_session() as sess:
- value = sess.run(element, feed_dict={stop_t: stop})
- self.assertEqual(stop * (stop - 1) / 2, value)
-
if __name__ == "__main__":
test.main()
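A small sketch of the API this test continues to exercise after the move: `get_single_element()` pulls the one element out of a single-element dataset as a tensor (the range/skip/take values are arbitrary):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(100).skip(5).take(1)
    element = tf.data.experimental.get_single_element(dataset)

    with tf.Session() as sess:
      print(sess.run(element))  # prints 5
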
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
index d4d3d4adb2..c93a8353ce 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import unittest
-from tensorflow.contrib.data.python.ops import indexed_dataset_ops
+from tensorflow.python.data.experimental.ops import indexed_dataset_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py
index 28bd670ab5..560902caad 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py
@@ -24,7 +24,7 @@ import time
from six.moves import zip_longest
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py
index 58a1d7c93b..94393d6d4b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.experimental.ops import iterator_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py
index 385c4ef6ea..2f0bd1456b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py
@@ -24,11 +24,11 @@ import time
import numpy as np
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import optimization
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
index 751e6d5b30..612ee332c4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import time
-from tensorflow.contrib.data.python.ops import map_defun
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import map_defun
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index d7b5edcd9a..68f73bddb5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -12,9 +12,9 @@ py_test(
srcs = ["assert_next_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
@@ -26,12 +26,12 @@ py_test(
srcs = ["hoist_random_uniform_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
@@ -44,11 +44,11 @@ py_test(
srcs = ["latency_all_edges_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -59,7 +59,6 @@ py_test(
srcs = ["map_vectorization_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -68,6 +67,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
@@ -81,12 +81,12 @@ py_test(
srcs = ["map_and_filter_fusion_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
@@ -99,12 +99,12 @@ py_test(
srcs = ["map_parallelization_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
@@ -120,11 +120,11 @@ py_test(
"optonly",
],
deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
@@ -137,11 +137,11 @@ py_test(
srcs = ["noop_elimination_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
@@ -154,9 +154,9 @@ py_test(
srcs = ["optimize_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py
index fe1b5280ba..45b77b5c20 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
index b43efb5c7c..3cd9753665 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
index e4f18222fd..45623876ae 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
@@ -17,9 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
index e9e3fc81e5..a439635716 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
index f7907eb890..334d8e3778 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index a5ea85f454..d47492753e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -22,8 +22,8 @@ import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
index 33c250ab2a..a9f2ce8c03 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
@@ -21,8 +21,8 @@ import time
import numpy as np
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
index b9e60cfa4e..092e0ff62a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
index 04f499f8c5..eb661796c0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py
index 66ccaceea5..13f924b656 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py
@@ -22,9 +22,9 @@ import copy
import numpy as np
-from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsing_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
@@ -846,6 +846,5 @@ class ParseExampleTest(test_base.DatasetTestBase):
"allow_missing to be True."))
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py
index 7a6a7a709a..7d7b842c17 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py
@@ -19,9 +19,9 @@ from __future__ import print_function
import threading
-from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
+from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py
index 2e901587f4..22412c3965 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import counter
-from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.python.data.experimental.ops import counter
+from tensorflow.python.data.experimental.ops import enumerate_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py
index 66ed547b6d..a02f4bd14f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py
@@ -23,8 +23,8 @@ import zlib
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
index f443b5501b..b6ab80d132 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
@@ -22,9 +22,9 @@ import gzip
import os
import zlib
-from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/python/data/experimental/kernel_tests/resample_test.py
index 32474bd411..775648c943 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/resample_test.py
@@ -23,7 +23,7 @@ from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.contrib.data.python.ops import resampling
+from tensorflow.python.data.experimental.ops import resampling
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py
index bdf80eae4e..78ec80de23 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py
@@ -21,7 +21,7 @@ import itertools
import numpy as np
-from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index aa89674c6e..20c02a5366 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -13,7 +13,6 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@@ -24,6 +23,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//third_party/py/numpy",
],
@@ -37,10 +37,10 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -81,9 +81,9 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
],
)
@@ -126,8 +126,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python/data/ops:readers",
],
)
@@ -160,8 +160,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -174,8 +174,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -189,9 +189,9 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:error_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:error_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -222,9 +222,9 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:client_testlib",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -258,8 +258,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -288,10 +288,10 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -326,8 +326,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
],
)
@@ -370,8 +370,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -384,8 +384,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:scan_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -411,10 +411,10 @@ py_test(
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -427,8 +427,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -441,10 +441,10 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -457,11 +457,11 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_op_test_base",
+ "//tensorflow/python/data/experimental/ops:readers",
],
)
@@ -473,10 +473,10 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -490,8 +490,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python/data/ops:readers",
],
)
@@ -505,8 +505,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python/data/ops:readers",
],
)
@@ -519,8 +519,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -534,8 +534,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:unique",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:unique",
"//tensorflow/python/data/ops:dataset_ops",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
index af87d8b608..d72a6df14c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
index 1b6059ccbc..2bcf77f5d8 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
@@ -21,7 +21,7 @@ import os
from absl.testing import parameterized
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
index 96f13d75a3..c075dff8cb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
index 247f2046ea..d4983492e7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import gzip
import os
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
index 2139b5c33d..41a095fb1a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
new file mode 100644
index 0000000000..7f435b8239
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -0,0 +1,692 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing serializable datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.util import nest
+
+
+def remove_variants(get_next_op):
+  """Remove variants from a nest structure, so sess.run will execute."""
+  # TODO(b/72408568): Remove this once session.run can get
+  # variant tensors.
+
+ def _remove_variant(x):
+ if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
+ return ()
+ else:
+ return x
+
+ return nest.map_structure(_remove_variant, get_next_op)
+
+
+class DatasetSerializationTestBase(test.TestCase):
+ """Base class for testing serializable datasets."""
+
+ def tearDown(self):
+ self._delete_ckpt()
+
+  # TODO(b/72657739): Remove the `sparse_tensors` argument, which is used to
+  # test the (deprecated) saveable `SparseTensorSliceDataset`, once the API
+  # `from_sparse_tensor_slices()` and related tests are deleted.
+ def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
+ """Runs the core tests.
+
+ Args:
+ ds_fn1: 0-argument function that returns a Dataset.
+ ds_fn2: 0-argument function that returns a Dataset different from
+ ds_fn1. If None, verify_restore_in_modified_graph test is not run.
+ num_outputs: Total number of outputs expected from this Dataset.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_unused_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_fully_used_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_exhausted_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_init_before_restore(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_multiple_breaks(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_reset_restored_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_restore_in_empty_graph(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ if ds_fn2:
+ self.verify_restore_in_modified_graph(
+ ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_unused_iterator(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that saving and restoring an unused iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [0],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_fully_used_iterator(self, ds_fn, num_outputs,
+ sparse_tensors=False):
+ """Verifies that saving and restoring a fully used iterator works.
+
+ Note that this only checks saving and restoring an iterator from which
+ `num_outputs` items have been produced but does not check for an
+ exhausted iterator, i.e., one from which an OutOfRange error has been
+ returned.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
+ """Verifies that saving and restoring an exhausted iterator works.
+
+ An exhausted iterator is one which has returned an OutOfRange error.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ actual = self.gen_outputs(
+ ds_fn, [],
+ 0,
+ ckpt_saved=True,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ self.assertEqual(len(actual), 0)
+
+ def verify_init_before_restore(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that restoring into an already initialized iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs),
+ num_outputs,
+ init_before_restore=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_multiple_breaks(self,
+ ds_fn,
+ num_outputs,
+ num_breaks=10,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to save/restore at multiple break points.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+      num_breaks: The number of break points. These are spread uniformly
+        over [0, num_outputs], inclusive of both endpoints.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs, num_breaks),
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_reset_restored_iterator(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to re-initialize a restored iterator.
+
+ This is useful when restoring a training checkpoint during validation.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Collect ground truth containing all outputs.
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Skip some items and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Restore from checkpoint and then run init_op.
+ with ops.Graph().as_default() as g:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ self._initialize(init_op, sess)
+ for _ in range(num_outputs):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ self.match(expected, actual)
+
+ def verify_restore_in_modified_graph(self,
+ ds_fn1,
+ ds_fn2,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in a modified graph.
+
+ Builds an input pipeline using ds_fn1, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new graph using ds_fn2, restores
+ the checkpoint from ds_fn1 and verifies that the restore is successful.
+
+ Args:
+ ds_fn1: See `run_core_tests`.
+ ds_fn2: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn1
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn1, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn1 and save checkpoint.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build graph for ds_fn2 but load checkpoint for ds_fn1.
+ with ops.Graph().as_default() as g:
+ _, get_next_op, saver = self._build_graph(
+ ds_fn2, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_restore_in_empty_graph(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in an empty graph.
+
+ Builds an input pipeline using ds_fn, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new empty graph, restores
+ the checkpoint from ds_fn and verifies that the restore is successful.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build an empty graph but load checkpoint for ds_fn.
+ with ops.Graph().as_default() as g:
+ get_next_op, saver = self._build_empty_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_error_on_save(self,
+ ds_fn,
+ num_outputs,
+ error,
+ break_point=None,
+ sparse_tensors=False):
+ """Attempts to save a non-saveable iterator.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ error: Declared error when trying to save iterator.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+
+ break_point = num_outputs // 2 if not break_point else break_point
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._initialize(init_op, sess)
+ for _ in range(break_point):
+ sess.run(get_next_op)
+ with self.assertRaises(error):
+ self._save(sess, saver)
+
+ def verify_run_with_breaks(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that ds_fn() produces the same outputs with and without breaks.
+
+ 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ *without* stopping at break points.
+ 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ with stopping at break points.
+
+    Deep-matches the outputs from 1 and 2.
+
+ Args:
+ ds_fn: See `gen_outputs`.
+ break_points: See `gen_outputs`.
+ num_outputs: See `gen_outputs`.
+ init_before_restore: See `gen_outputs`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ self.match(expected, actual)
+
+ def gen_outputs(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ ckpt_saved=False,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True,
+ save_checkpoint_at_end=True):
+ """Generates elements from input dataset while stopping at break points.
+
+ Produces `num_outputs` outputs and saves the state of the iterator in the
+ Saver checkpoint.
+
+ Args:
+ ds_fn: 0-argument function that returns the dataset.
+      break_points: A list of integers. For each `break_point` in
+        `break_points`, we produce outputs until `break_point` items have
+        been produced and then checkpoint the state. The current graph and
+        session are destroyed and a new graph and session are used to
+        produce outputs until the next checkpoint or until `num_outputs`
+        elements have been produced. `break_point` must be <= `num_outputs`.
+ num_outputs: The total number of outputs to produce from the iterator.
+ ckpt_saved: Whether a checkpoint already exists. If False, we build the
+ graph from ds_fn.
+ init_before_restore: Whether init should be called before saver.restore.
+ This is just so that we can verify that restoring an already initialized
+ iterator works.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+ verify_exhausted: Whether to verify that the iterator has been exhausted
+ after producing `num_outputs` elements.
+      save_checkpoint_at_end: Whether to save a checkpoint after producing all
+        outputs. If False, checkpoints are saved at each break point but not
+        at the end. Note that checkpoints overwrite each other so there is
+        always only a single checkpoint available. Defaults to True.
+
+ Returns:
+ A list of `num_outputs` items.
+ """
+ outputs = []
+
+ def get_ops():
+ if ckpt_saved:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ else:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ return init_op, get_next_op, saver
+
+ for i in range(len(break_points) + 1):
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = get_ops()
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ if ckpt_saved:
+ if init_before_restore:
+ self._initialize(init_op, sess)
+ self._restore(saver, sess)
+ else:
+ self._initialize(init_op, sess)
+ start = break_points[i - 1] if i > 0 else 0
+ end = break_points[i] if i < len(break_points) else num_outputs
+ num_iters = end - start
+ for _ in range(num_iters):
+ outputs.append(sess.run(get_next_op))
+ if i == len(break_points) and verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ if save_checkpoint_at_end or i < len(break_points):
+ self._save(sess, saver)
+ ckpt_saved = True
+
+ return outputs
+
+ def match(self, expected, actual):
+ """Matches nested structures.
+
+ Recursively matches shape and values of `expected` and `actual`.
+    Handles scalars, numpy arrays, and other Python containers such as
+    lists and dicts.
+
+ Args:
+ expected: Nested structure 1.
+ actual: Nested structure 2.
+
+ Raises:
+ AssertionError if matching fails.
+ """
+ if isinstance(expected, np.ndarray):
+ expected = expected.tolist()
+ if isinstance(actual, np.ndarray):
+ actual = actual.tolist()
+ self.assertEqual(type(expected), type(actual))
+
+ if nest.is_sequence(expected):
+ self.assertEqual(len(expected), len(actual))
+ if isinstance(expected, dict):
+ for key1, key2 in zip(sorted(expected), sorted(actual)):
+ self.assertEqual(key1, key2)
+ self.match(expected[key1], actual[key2])
+ else:
+ for item1, item2 in zip(expected, actual):
+ self.match(item1, item2)
+ else:
+ self.assertEqual(expected, actual)
+
+ def does_not_match(self, expected, actual):
+ with self.assertRaises(AssertionError):
+ self.match(expected, actual)
+
+ def gen_break_points(self, num_outputs, num_samples=10):
+ """Generates `num_samples` breaks points in [0, num_outputs]."""
+ return np.linspace(0, num_outputs, num_samples, dtype=int)
+
+ def _build_graph(self, ds_fn, sparse_tensors=False):
+ iterator = ds_fn().make_initializable_iterator()
+
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ init_op = iterator.initializer
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
+ sparse_tensors)
+ saver = saver_lib.Saver(allow_empty=True)
+ return init_op, get_next, saver
+
+ def _build_empty_graph(self, ds_fn, sparse_tensors=False):
+ iterator = iterator_ops.Iterator.from_structure(
+ self._get_output_types(ds_fn),
+ output_shapes=self._get_output_shapes(ds_fn),
+ output_classes=self._get_output_classes(ds_fn))
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ saver = saver_lib.Saver(allow_empty=True)
+ return get_next, saver
+
+ def _add_iterator_ops_to_collection(self,
+ init_op,
+ get_next,
+ ds_fn,
+ sparse_tensors=False):
+ ops.add_to_collection("iterator_ops", init_op)
+    # `get_next` may be a tuple, e.g. in TensorSliceDataset. Since graph
+    # collections do not support tuples, we flatten the tensors and restore
+    # the structure in `_get_iterator_ops_from_collection`.
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ ops.add_to_collection("iterator_ops", get_next.indices)
+ ops.add_to_collection("iterator_ops", get_next.values)
+ ops.add_to_collection("iterator_ops", get_next.dense_shape)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
+
+ def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
+ all_ops = ops.get_collection("iterator_ops")
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ init_op, indices, values, dense_shape = all_ops
+ return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
+
+ def _get_output_types(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_types
+
+ def _get_output_shapes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_shapes
+
+ def _get_output_classes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_classes
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _latest_ckpt(self):
+ return checkpoint_management.latest_checkpoint(self.get_temp_dir())
+
+ def _save(self, sess, saver):
+ saver.save(sess, self._ckpt_path())
+
+ def _restore(self, saver, sess):
+ sess.run(lookup_ops.tables_initializer())
+ saver.restore(sess, self._latest_ckpt())
+
+ def _initialize(self, init_op, sess):
+ sess.run(variables.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ sess.run(init_op)
+
+ def _import_meta_graph(self):
+ meta_file_path = self._ckpt_path() + ".meta"
+ return saver_lib.import_meta_graph(meta_file_path)
+
+ def _delete_ckpt(self):
+ # Remove all checkpoint files.
+ prefix = self._ckpt_path()
+    pattern = prefix + "*"
+    for checkpoint_file in gfile.Glob(pattern):
+      gfile.Remove(checkpoint_file)
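For context on how this new base class is consumed, here is a minimal hypothetical sketch of a serialization test built on it (the class and test names are illustrative and not part of this change): a subclass supplies a 0-argument dataset factory and calls `run_core_tests` with the expected number of elements, passing `None` for `ds_fn2` to skip the modified-graph restore check.

    from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.platform import test


    class RangeDatasetSerializationTest(
        dataset_serialization_test_base.DatasetSerializationTestBase):

      def _build_range_dataset(self, stop):
        # 0-argument factories are obtained by closing over `stop` below.
        return dataset_ops.Dataset.range(stop)

      def testCore(self):
        num_outputs = 10
        # Passing None for ds_fn2 skips verify_restore_in_modified_graph.
        self.run_core_tests(lambda: self._build_range_dataset(num_outputs),
                            None, num_outputs)


    if __name__ == "__main__":
      test.main()

The individual `verify_*` helpers can also be called directly when a test only needs one save/restore scenario, as several of the serialization tests below do.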
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
index 7c170078a1..225f6cbac0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
index 34392d88d4..70caf3e0d5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
index 16051ffd3f..c30534a9e9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
index 571e0899bb..169c8845d0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
index f86af4084e..e5bc76288e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
index 65ae9923b8..df1f43129a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import error_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
index 243f6405a1..0c1d40ce39 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import sparse_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
index c9cd211328..166ffa99ca 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import math
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
index ab783e5cce..b93156a96c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
index d5c03495e3..ed4a1da596 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
index 9ac42a461a..6f72b24673 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
index 1f8a584df9..b8f38e8a28 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import sparse_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
index 3fb7605be1..a0bdd4fa59 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -65,7 +65,7 @@ class ParallelMapDatasetSerializationTest(
for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
self.run_core_tests(
ds_fn,
- lambda: ds_fn(multiplier=15.0),
+ lambda: ds_fn(multiplier=15.0), # pylint: disable=cell-var-from-loop
self._num_outputs)
def testSaveStatefulFunction(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
index d3fa84e74c..a0dd6960b0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
index c802402461..00d74c0025 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
index 6341190847..ef99d01c73 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
index fdb35ea624..c23c1ecdfb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
index af9ef48c0f..5f50160619 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
index 2afebca0f5..fe99a3d3d9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
index 6aac50ecd9..88d5c896c9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
index f199ec835e..f847ac19f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
index a59fa94d66..a04f1ddafc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
index 93b26ed58a..b179770ce3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
@@ -19,9 +19,9 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
index a10f85263a..ef7061b190 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
index 2483787f44..c87a7443a7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
index 55a6257a27..f0dcc131d4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
@@ -21,8 +21,8 @@ import gzip
import os
import zlib
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
index b2a5a8a20d..528598dfe4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
index 22f15b8846..e2862af4d6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
index 340a6ff72e..4ea6131c22 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
new file mode 100644
index 0000000000..88d5c896c9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Integration test for dataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class SerializationIntegrationTest(test.TestCase):
+
+ def _build_input_pipeline(self, name, num_outputs):
+ with ops.name_scope(name):
+ ds = dataset_ops.Dataset.range(num_outputs).shuffle(
+ 10, reshuffle_each_iteration=False).prefetch(10)
+ iterator = ds.make_initializable_iterator()
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ return iterator.initializer, iterator.get_next()
+
+ def _build_graph(self, num_pipelines, num_outputs):
+ init_ops = []
+ get_next_ops = []
+ for i in range(num_pipelines):
+ name = "input_pipeline_%d" % i
+ init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
+ init_ops.append(init_op)
+ get_next_ops.append(get_next_op)
+ saver = saver_lib.Saver()
+ return init_ops, get_next_ops, saver
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def testConcurrentSaves(self):
+ num_pipelines = 100
+ num_outputs = 100
+ break_point = 10
+ all_outputs = [[] for _ in range(num_pipelines)]
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ sess.run(init_ops)
+ for _ in range(break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+ saver.save(sess, self._ckpt_path())
+
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ saver.restore(sess, self._ckpt_path())
+ for _ in range(num_outputs - break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+
+ for output in all_outputs:
+ self.assertSequenceEqual(sorted(output), range(num_outputs))
+
+
+if __name__ == "__main__":
+ test.main()
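The integration test above exercises the iterator-checkpointing pattern that `make_saveable_from_iterator` enables: registering the saveable object puts the iterator's position into every checkpoint written by a `Saver`. A minimal sketch of the same pattern outside the test harness (the checkpoint path and the use of a raw `Session` are placeholders, not part of this change):

```python
# Sketch: checkpoint and resume tf.data iterator state with a Saver
# (TF 1.x graph mode). `ckpt_path` is a placeholder.
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.experimental.ops import iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import saver as saver_lib

ckpt_path = "/tmp/iterator_ckpt"

with ops.Graph().as_default():
  ds = dataset_ops.Dataset.range(100).shuffle(
      10, reshuffle_each_iteration=False).prefetch(10)
  iterator = ds.make_initializable_iterator()
  # Registering the saveable makes Saver.save()/restore() include the
  # iterator's current position in the checkpoint.
  saveable = iterator_ops.make_saveable_from_iterator(iterator)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
  get_next = iterator.get_next()
  saver = saver_lib.Saver()

  with session_lib.Session() as sess:
    sess.run(iterator.initializer)
    for _ in range(10):
      sess.run(get_next)
    saver.save(sess, ckpt_path)     # checkpoint after consuming 10 elements

  with session_lib.Session() as sess:
    saver.restore(sess, ckpt_path)  # no initializer needed after restore
    remaining = [sess.run(get_next) for _ in range(90)]
```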
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py
index c97002a255..50895b5945 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.experimental.ops import shuffle_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py
index 52823d3fca..301f75488a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py
index 319a2ea263..a135c357f0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py
@@ -23,7 +23,7 @@ import os
import sqlite3
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index be8ae5e955..6761fbd16b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
+from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
index 80f2625927..80f2625927 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py
index 08de3a9143..4432dcb05a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py
@@ -22,8 +22,8 @@ import threading
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.ops import threadpool
-from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.experimental.ops import threadpool
+from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
index 8856ce5afb..b5a0b20f3f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py
index fca546a570..25a2e63ba1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.ops import writers
+from tensorflow.python.data.experimental.ops import writers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
new file mode 100644
index 0000000000..915d399f1b
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -0,0 +1,377 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+)
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+
+py_library(
+ name = "counter",
+ srcs = ["counter.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":scan_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "get_single_element",
+ srcs = ["get_single_element.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "iterator_ops",
+ srcs = [
+ "iterator_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:basic_session_run_hooks",
+ "//tensorflow/python:checkpoint_management",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:saver",
+ "//tensorflow/python:session_run_hook",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ ],
+)
+
+py_library(
+ name = "random_ops",
+ srcs = [
+ "random_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "readers",
+ srcs = [
+ "readers.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":batching",
+ ":interleave_ops",
+ ":optimization",
+ ":parsing_ops",
+ ":shuffle_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:convert",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "shuffle_ops",
+ srcs = [
+ "shuffle_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "batching",
+ srcs = ["batching.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":get_single_element",
+ ":grouping",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:convert",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "enumerate_ops",
+ srcs = ["enumerate_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "error_ops",
+ srcs = ["error_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "grouping",
+ srcs = ["grouping.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "interleave_ops",
+ srcs = ["interleave_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":random_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:stateless_random_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "optimization",
+ srcs = ["optimization.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "parsing_ops",
+ srcs = ["parsing_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
+ name = "map_defun",
+ srcs = ["map_defun.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
+py_library(
+ name = "resampling",
+ srcs = ["resampling.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":batching",
+ ":interleave_ops",
+ ":scan_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:logging_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "scan_ops",
+ srcs = ["scan_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "stats_ops",
+ srcs = ["stats_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "threadpool",
+ srcs = ["threadpool.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_library(
+ name = "unique",
+ srcs = [
+ "unique.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "writers",
+ srcs = [
+ "writers.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "indexed_dataset_ops",
+ srcs = ["indexed_dataset_ops.py"],
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "prefetching_ops",
+ srcs = ["prefetching_ops.py"],
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "dataset_ops",
+ deps = [
+ ":batching",
+ ":counter",
+ ":enumerate_ops",
+ ":error_ops",
+ ":get_single_element",
+ ":grouping",
+ ":indexed_dataset_ops",
+ ":interleave_ops",
+ ":map_defun",
+ ":optimization",
+ ":prefetching_ops",
+ ":readers",
+ ":resampling",
+ ":scan_ops",
+ ":shuffle_ops",
+ ":stats_ops",
+ ":threadpool",
+ ":unique",
+ ":writers",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
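The `batching` target declared above is the library behind transformations such as `unbatch()` and `dense_to_sparse_batch()` (defined in `batching.py`, next diff). A minimal sketch of how code now imports it from the relocated package and applies one of its transformations, with an illustrative input that is not taken from this change:

```python
# Sketch: import a transformation from the new experimental package and
# apply it with Dataset.apply(). The input values are illustrative only.
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops

# One element of shape [3]; unbatch() splits it into three scalar elements.
ds = dataset_ops.Dataset.from_tensors([1, 2, 3])
unbatched = ds.apply(batching.unbatch())
# The same transformation is exported publicly as tf.data.experimental.unbatch().
```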
diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py
new file mode 100644
index 0000000000..d42af9e7e9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/batching.py
@@ -0,0 +1,669 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Batching dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import get_single_element
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def batch_window(dataset):
+ """Batches a window of tensors.
+
+ Args:
+ dataset: the input dataset.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+ """
+ if isinstance(dataset.output_classes, tuple):
+ raise TypeError("Input dataset expected to have a single component")
+ if dataset.output_classes is ops.Tensor:
+ return _batch_dense_window(dataset)
+ elif dataset.output_classes is sparse_tensor.SparseTensor:
+ return _batch_sparse_window(dataset)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _batch_dense_window(dataset):
+ """Batches a window of dense tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return array_ops.shape(first_element)
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, array_ops.shape(value))
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat([[0], shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _batch_sparse_window(dataset):
+ """Batches a window of sparse tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return first_element.dense_shape
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, value.dense_shape)
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64),
+ math_ops.cast(shape, dtypes.int64)], 0))
+
+ def batch_reduce_fn(state, value):
+ return sparse_ops.sparse_concat(0, [state, value])
+
+ def reshape_fn(value):
+ return sparse_ops.sparse_reshape(
+ value,
+ array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(reshape_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+@tf_export("data.experimental.dense_to_sparse_batch")
+def dense_to_sparse_batch(batch_size, row_shape):
+ """A transformation that batches ragged elements into `tf.SparseTensor`s.
+
+ Like `Dataset.padded_batch()`, this transformation combines multiple
+ consecutive elements of the dataset, which might have different
+ shapes, into a single element. The resulting element has three
+ components (`indices`, `values`, and `dense_shape`), which
+ comprise a `tf.SparseTensor` that represents the same data. The
+ `row_shape` represents the dense shape of each row in the
+ resulting `tf.SparseTensor`, to which the effective batch size is
+ prepended. For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.data.experimental.dense_to_sparse_batch(
+ batch_size=2, row_shape=[6])) ==
+ {
+ ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices
+ ['a', 'b', 'c', 'a', 'b'], # values
+ [2, 6]), # dense_shape
+ ([[0, 0], [0, 1], [0, 2], [0, 3]],
+ ['a', 'b', 'c', 'd'],
+ [1, 6])
+ }
+ ```
+
+ Args:
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the
+ number of consecutive elements of this dataset to combine in a
+ single batch.
+ row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the equivalent dense shape of a row in the
+ resulting `tf.SparseTensor`. Each element of this dataset must
+ have the same rank as `row_shape`, and must have size less
+ than or equal to `row_shape` in each dimension.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+
+ return _apply_fn
+
+
+def padded_batch_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of tensors with padding.
+
+ Args:
+ dataset: the input dataset.
+    padded_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the shape to which the input elements should be padded
+ prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
+ `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
+ maximum size of that dimension in each batch.
+ padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
+ padding value to use. Defaults are `0` for numeric types and the empty
+ string for string types. If `dataset` contains `tf.SparseTensor`, this
+ value is ignored.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+
+ Raises:
+ ValueError: if invalid arguments are provided.
+ """
+ if not issubclass(dataset.output_classes,
+ (ops.Tensor, sparse_tensor.SparseTensor)):
+ raise TypeError("Input dataset expected to have a single tensor component")
+ if issubclass(dataset.output_classes, (ops.Tensor)):
+ return _padded_batch_dense_window(dataset, padded_shape, padding_value)
+ elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)):
+ if padding_value is not None:
+ raise ValueError("Padding value not allowed for sparse tensors")
+ return _padded_batch_sparse_window(dataset, padded_shape)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of dense tensors with padding."""
+
+ padded_shape = math_ops.cast(
+ convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return padded_shape
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(array_ops.shape(value), padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ",
+ array_ops.shape(value), padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, array_ops.shape(value))
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ if padding_value is None:
+ if dataset.output_types == dtypes.string:
+ padding_value = ""
+ elif dataset.output_types == dtypes.bool:
+ padding_value = False
+ elif dataset.output_types == dtypes.variant:
+ raise TypeError("Unable to create padding for field of type 'variant'")
+ else:
+ padding_value = 0
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat(
+ [np.array([0], dtype=np.int32), padded_shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ def pad_fn(value):
+ shape = array_ops.shape(value)
+ left = array_ops.zeros_like(shape)
+ right = padded_shape - shape
+ return array_ops.pad(
+ value, array_ops.stack([left, right], 1), constant_values=padding_value)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(pad_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _padded_batch_sparse_window(dataset, padded_shape):
+ """Batches a window of sparse tensors with padding."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return convert.partial_shape_to_tensor(padded_shape)
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(value.dense_shape, padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ", value.dense_shape,
+ padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, value.dense_shape)
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
+ 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64), padded_shape], 0))
+
+ def batch_reduce_fn(state, value):
+ padded_value = sparse_tensor.SparseTensor(
+ indices=value.indices, values=value.values, dense_shape=padded_shape)
+ reshaped_value = sparse_ops.sparse_reshape(
+ padded_value,
+ array_ops.concat(
+ [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
+ return sparse_ops.sparse_concat(0, [state, reshaped_value])
+
+ reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
+
+
+class _UnbatchDataset(dataset_ops.UnaryDataset):
+ """A dataset that splits the elements of its input into multiple elements."""
+
+ def __init__(self, input_dataset):
+ """See `unbatch()` for more details."""
+ super(_UnbatchDataset, self).__init__(input_dataset)
+ flat_shapes = nest.flatten(input_dataset.output_shapes)
+ if any(s.ndims == 0 for s in flat_shapes):
+ raise ValueError("Cannot unbatch an input with scalar components.")
+ known_batch_dim = tensor_shape.Dimension(None)
+ for s in flat_shapes:
+ try:
+ known_batch_dim = known_batch_dim.merge_with(s[0])
+ except ValueError:
+ raise ValueError("Cannot unbatch an input whose components have "
+ "different batch sizes.")
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.unbatch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda s: s[1:],
+ self._input_dataset.output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+@tf_export("data.experimental.unbatch")
+def unbatch():
+ """Splits elements of a dataset into multiple elements on the batch dimension.
+
+ For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
+ where `B` may vary for each input element, then for each element in the
+ dataset, the unbatched dataset will contain `B` consecutive elements
+ of shape `[a0, a1, ...]`.
+
+ ```python
+ # NOTE: The following example uses `{ ... }` to represent the contents
+ # of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.data.experimental.unbatch()) == {
+ 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ if not sparse.any_sparse(dataset.output_classes):
+ return _UnbatchDataset(dataset)
+
+ # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+ # are normalized to the rank-1 dense representation, so that the
+ # sparse-oblivious unbatching logic will slice them
+ # appropriately. This leads to a somewhat inefficient re-encoding step
+ # for all SparseTensor components.
+ # TODO(mrry): Consider optimizing this in future
+ # if it turns out to be a bottleneck.
+ def normalize(arg, *rest):
+ if rest:
+ return sparse.serialize_many_sparse_tensors((arg,) + rest)
+ else:
+ return sparse.serialize_many_sparse_tensors(arg)
+
+ normalized_dataset = dataset.map(normalize)
+
+ # NOTE(mrry): Our `map()` has lost information about the sparseness
+ # of any SparseTensor components, so re-apply the structure of the
+ # original dataset.
+ restructured_dataset = _RestructuredDataset(
+ normalized_dataset,
+ dataset.output_types,
+ dataset.output_shapes,
+ dataset.output_classes,
+ allow_unsafe_cast=True)
+ return _UnbatchDataset(restructured_dataset)
+
+ return _apply_fn
+
+
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
+
+ def __init__(self, input_dataset, batch_size, row_shape):
+ """See `Dataset.dense_to_sparse_batch()` for more details."""
+ super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
+ if not isinstance(input_dataset.output_types, dtypes.DType):
+ raise TypeError("DenseToSparseDataset requires an input whose elements "
+ "have a single component, whereas the input has %r." %
+ input_dataset.output_types)
+ self._input_dataset = input_dataset
+ self._batch_size = batch_size
+ self._row_shape = row_shape
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.dense_to_sparse_batch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._batch_size,
+ row_shape=convert.partial_shape_to_tensor(self._row_shape),
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return sparse_tensor.SparseTensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.vector(None).concatenate(self._row_shape)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _RestructuredDataset(dataset_ops.UnaryDataset):
+ """An internal helper for changing the structure and shape of a dataset."""
+
+ def __init__(self,
+ dataset,
+ output_types,
+ output_shapes=None,
+ output_classes=None,
+ allow_unsafe_cast=False):
+ """Creates a new dataset with the given output types and shapes.
+
+ The given `dataset` must have a structure that is convertible:
+    * `dataset.output_types` must be the same as `output_types` modulo nesting.
+ * Each shape in `dataset.output_shapes` must be compatible with each shape
+ in `output_shapes` (if given).
+
+ Note: This helper permits "unsafe casts" for shapes, equivalent to using
+ `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
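+    For example, a minimal sketch (internal use only) of re-asserting a
+    known static shape on a single-component dataset:
+
+    ```python
+    restructured = _RestructuredDataset(
+        dataset,
+        dataset.output_types,
+        output_shapes=tensor_shape.TensorShape([128]),
+        allow_unsafe_cast=True)
+    ```
+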
+ Args:
+ dataset: A `Dataset` object.
+ output_types: A nested structure of `tf.DType` objects.
+ output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+ If omitted, the shapes will be inherited from `dataset`.
+ output_classes: (Optional.) A nested structure of class types.
+ If omitted, the class types will be inherited from `dataset`.
+ allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+ reported output types and shapes of the restructured dataset, e.g. to
+ switch a sparse tensor represented as `tf.variant` to its user-visible
+ type and shape.
+
+ Raises:
+ ValueError: If either `output_types` or `output_shapes` is not compatible
+ with the structure of `dataset`.
+ """
+ super(_RestructuredDataset, self).__init__(dataset)
+ self._input_dataset = dataset
+
+ if not allow_unsafe_cast:
+ # Validate that the types are compatible.
+ output_types = nest.map_structure(dtypes.as_dtype, output_types)
+ flat_original_types = nest.flatten(dataset.output_types)
+ flat_new_types = nest.flatten(output_types)
+ if flat_original_types != flat_new_types:
+ raise ValueError(
+ "Dataset with output types %r cannot be restructured to have "
+ "output types %r" % (dataset.output_types, output_types))
+
+ self._output_types = output_types
+
+ if output_shapes is None:
+ # Inherit shapes from the original `dataset`.
+ self._output_shapes = nest.pack_sequence_as(output_types,
+ nest.flatten(
+ dataset.output_shapes))
+ else:
+ if not allow_unsafe_cast:
+ # Validate that the shapes are compatible.
+ nest.assert_same_structure(output_types, output_shapes)
+ flat_original_shapes = nest.flatten(dataset.output_shapes)
+ flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes):
+ if not original_shape.is_compatible_with(new_shape):
+ raise ValueError(
+ "Dataset with output shapes %r cannot be restructured to have "
+ "incompatible output shapes %r" % (dataset.output_shapes,
+ output_shapes))
+ self._output_shapes = nest.map_structure_up_to(
+ output_types, tensor_shape.as_shape, output_shapes)
+ if output_classes is None:
+ # Inherit class types from the original `dataset`.
+ self._output_classes = nest.pack_sequence_as(output_types,
+ nest.flatten(
+ dataset.output_classes))
+ else:
+ self._output_classes = output_classes
+
+ def _as_variant_tensor(self):
+ return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+
+class _MapAndBatchDataset(dataset_ops.MapDataset):
+ """A `Dataset` that maps a function over a batch of elements."""
+
+ def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
+ drop_remainder):
+ """See `Dataset.map()` for details."""
+ super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
+ self._batch_size_t = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+ self._num_parallel_calls_t = ops.convert_to_tensor(
+ num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+ self._drop_remainder_t = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+
+ self._batch_size = batch_size
+ self._drop_remainder = drop_remainder
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ input_resource = self._input_dataset._as_variant_tensor()
+ return gen_dataset_ops.map_and_batch_dataset_v2(
+ input_resource,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ batch_size=self._batch_size_t,
+ num_parallel_calls=self._num_parallel_calls_t,
+ drop_remainder=self._drop_remainder_t,
+ **dataset_ops.flat_structure(self))
+ # pylint: enable=protected-access
+
+ @property
+ def output_shapes(self):
+ dim = self._batch_size if self._drop_remainder else None
+ return nest.pack_sequence_as(self._output_shapes, [
+ tensor_shape.vector(dim).concatenate(s)
+ for s in nest.flatten(self._output_shapes)
+ ])
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+@tf_export("data.experimental.map_and_batch")
+def map_and_batch(map_func,
+ batch_size,
+ num_parallel_batches=None,
+ drop_remainder=False,
+ num_parallel_calls=None):
+ """Fused implementation of `map` and `batch`.
+
+ Maps `map_func` across `batch_size` consecutive elements of this dataset
+ and then combines them into a batch. Functionally, it is equivalent to `map`
+ followed by `batch`. However, by fusing the two transformations together, the
+ implementation can be more efficient. Surfacing this transformation in the API
+ is temporary. Once automatic input pipeline optimization is implemented,
+ the fusing of `map` and `batch` will happen automatically and this API will be
+ deprecated.
+
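+  For example, the following is a minimal sketch of fusing a per-record
+  parsing step with batching; `filenames` and `parse_fn` are hypothetical
+  placeholders for your own inputs and map function:
+
+  ```python
+  dataset = tf.data.TFRecordDataset(filenames)
+  # Equivalent to `dataset.map(parse_fn).batch(32)`, but fused for efficiency.
+  dataset = dataset.apply(
+      tf.data.experimental.map_and_batch(parse_fn, batch_size=32))
+  ```
+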
+ Args:
+ map_func: A function mapping a nested structure of tensors to another
+ nested structure of tensors.
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements of this dataset to combine in a single batch.
+ num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
+ representing the number of batches to create in parallel. On one hand,
+ higher values can help mitigate the effect of stragglers. On the other
+ hand, higher values can increase contention if CPU is scarce.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether the last batch should be dropped in case its size is smaller than
+ desired; the default behavior is not to drop the smaller batch.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of elements to process in parallel. If not
+ specified, `batch_size * num_parallel_batches` elements will be
+ processed in parallel.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
+ specified.
+ """
+
+ if num_parallel_batches is None and num_parallel_calls is None:
+ num_parallel_calls = batch_size
+ elif num_parallel_batches is not None and num_parallel_calls is None:
+ num_parallel_calls = batch_size * num_parallel_batches
+ elif num_parallel_batches is not None and num_parallel_calls is not None:
+ raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
+ "arguments are mutually exclusive.")
+
+ def _apply_fn(dataset):
+ return _MapAndBatchDataset(dataset, map_func, batch_size,
+ num_parallel_calls, drop_remainder)
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/counter.py b/tensorflow/python/data/experimental/ops/counter.py
new file mode 100644
index 0000000000..42200eaef9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/counter.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Counter Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import scan_ops
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.Counter")
+def Counter(start=0, step=1, dtype=dtypes.int64):
+ """Creates a `Dataset` that counts from `start` in steps of size `step`.
+
+ For example:
+
+ ```python
+  tf.data.experimental.Counter() == [0, 1, 2, ...)
+  tf.data.experimental.Counter(2) == [2, 3, ...)
+  tf.data.experimental.Counter(2, 5) == [2, 7, 12, ...)
+  tf.data.experimental.Counter(0, -1) == [0, -1, -2, ...)
+  tf.data.experimental.Counter(10, -1) == [10, 9, ...)
+ ```
+
+ Args:
+ start: (Optional.) The starting value for the counter. Defaults to 0.
+ step: (Optional.) The step size for the counter. Defaults to 1.
+ dtype: (Optional.) The data type for counter elements. Defaults to
+ `tf.int64`.
+
+ Returns:
+ A `Dataset` of scalar `dtype` elements.
+ """
+ with ops.name_scope("counter"):
+ start = ops.convert_to_tensor(start, dtype=dtype, name="start")
+ step = ops.convert_to_tensor(step, dtype=dtype, name="step")
+ return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
+ scan_ops.scan(start, lambda state, _: (state + step, state)))
diff --git a/tensorflow/python/data/experimental/ops/enumerate_ops.py b/tensorflow/python/data/experimental/ops/enumerate_ops.py
new file mode 100644
index 0000000000..a1af98f552
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/enumerate_ops.py
@@ -0,0 +1,60 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Enumerate dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.enumerate_dataset")
+def enumerate_dataset(start=0):
+ """A transformation that enumerate the elements of a dataset.
+
+ It is Similar to python's `enumerate`.
+ For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { 1, 2, 3 }
+ b = { (7, 8), (9, 10) }
+
+  # The transformation pairs each element with its index, counting up
+  # from `start`.
+  a.apply(tf.data.experimental.enumerate_dataset(start=5)) == { (5, 1), (6, 2), (7, 3) }
+  b.apply(tf.data.experimental.enumerate_dataset()) == { (0, (7, 8)), (1, (9, 10)) }
+ ```
+
+ Args:
+ start: A `tf.int64` scalar `tf.Tensor`, representing the start
+ value for enumeration.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
+ return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
+ dataset))
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/error_ops.py b/tensorflow/python/data/experimental/ops/error_ops.py
new file mode 100644
index 0000000000..82e274b70c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/error_ops.py
@@ -0,0 +1,78 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ignore_errors dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.ignore_errors")
+def ignore_errors():
+ """Creates a `Dataset` from another `Dataset` and silently ignores any errors.
+
+ Use this transformation to produce a dataset that contains the same elements
+ as the input, but silently drops any elements that caused an error. For
+ example:
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
+
+ # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
+ dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))
+
+ # Using `ignore_errors()` will drop the element that causes an error.
+  dataset = dataset.apply(
+      tf.data.experimental.ignore_errors())  # ==> {1., 0.5, 0.2}
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _IgnoreErrorsDataset(dataset)
+
+ return _apply_fn
+
+
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that silently ignores errors when computing its input."""
+
+ def __init__(self, input_dataset):
+ """See `Dataset.ignore_errors()` for details."""
+ super(_IgnoreErrorsDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/get_single_element.py b/tensorflow/python/data/experimental/ops/get_single_element.py
new file mode 100644
index 0000000000..132526166c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/get_single_element.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for Datasets and Iterators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.get_single_element")
+def get_single_element(dataset):
+ """Returns the single element in `dataset` as a nested structure of tensors.
+
+ This function enables you to use a `tf.data.Dataset` in a stateless
+ "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`.
+ This can be useful when your preprocessing transformations are expressed
+ as a `Dataset`, and you want to use the transformation at serving time.
+ For example:
+
+ ```python
+ input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])
+
+ def preprocessing_fn(input_str):
+ # ...
+ return image, label
+
+ dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
+ .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
+ .batch(BATCH_SIZE))
+
+ image_batch, label_batch = tf.data.experimental.get_single_element(dataset)
+ ```
+
+ Args:
+ dataset: A `tf.data.Dataset` object containing a single element.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects, corresponding to the single
+ element of `dataset`.
+
+ Raises:
+ TypeError: if `dataset` is not a `tf.data.Dataset` object.
+ InvalidArgumentError (at runtime): if `dataset` does not contain exactly
+ one element.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+
+ nested_ret = nest.pack_sequence_as(
+ dataset.output_types, gen_dataset_ops.dataset_to_single_element(
+ dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(dataset)))
+ return sparse.deserialize_sparse_tensors(
+ nested_ret, dataset.output_types, dataset.output_shapes,
+ dataset.output_classes)
diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py
new file mode 100644
index 0000000000..18ba583220
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/grouping.py
@@ -0,0 +1,551 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Grouping dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.group_by_reducer")
+def group_by_reducer(key_func, reducer):
+ """A transformation that groups elements and performs a reduction.
+
+  This transformation maps each element of a dataset to a key using `key_func`
+  and groups the elements by key. The `reducer` is used to process each group:
+  its `init_func` initializes the state for a group when the group is created,
+  its `reduce_func` updates the state every time an element is mapped to that
+  group, and its `finalize_func` maps the final state to an output value.
+
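+  For example, a minimal sketch (assuming the usual `tf` and `np` imports)
+  that sums the elements sharing the same parity:
+
+  ```python
+  reducer = tf.data.experimental.Reducer(
+      init_func=lambda _: np.int64(0),
+      reduce_func=lambda state, value: state + value,
+      finalize_func=lambda state: state)
+  dataset = tf.data.Dataset.range(10).apply(
+      tf.data.experimental.group_by_reducer(
+          key_func=lambda x: x % 2, reducer=reducer))
+  # ==> { 20, 25 }  (the sums of the even and odd numbers in [0, 10))
+  ```
+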
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reducer: An instance of `Reducer`, which captures the reduction logic using
+ the `init_func`, `reduce_func`, and `finalize_func` functions.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _GroupByReducerDataset(dataset, key_func, reducer)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.group_by_window")
+def group_by_window(key_func,
+ reduce_func,
+ window_size=None,
+ window_size_func=None):
+ """A transformation that groups windows of elements by key and reduces them.
+
+ This transformation maps each consecutive element in a dataset to a key
+ using `key_func` and groups the elements by key. It then applies
+ `reduce_func` to at most `window_size_func(key)` elements matching the same
+ key. All except the final window for each key will contain
+ `window_size_func(key)` elements; the final window may be smaller.
+
+ You may provide either a constant `window_size` or a window size determined by
+ the key through `window_size_func`.
+
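+  For example, a minimal sketch that groups elements by parity and batches
+  each window of up to 4 elements:
+
+  ```python
+  dataset = tf.data.Dataset.range(10).apply(
+      tf.data.experimental.group_by_window(
+          key_func=lambda x: x % 2,
+          reduce_func=lambda key, window: window.batch(4),
+          window_size=4))
+  # ==> { [0, 2, 4, 6], [1, 3, 5, 7], [8], [9] }
+  ```
+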
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reduce_func: A function mapping a key and a dataset of up to `window_size`
+ consecutive elements matching that key to another dataset.
+ window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements matching the same key to combine in a single
+ batch, which will be passed to `reduce_func`. Mutually exclusive with
+ `window_size_func`.
+ window_size_func: A function mapping a key to a `tf.int64` scalar
+ `tf.Tensor`, representing the number of consecutive elements matching
+ the same key to combine in a single batch, which will be passed to
+ `reduce_func`. Mutually exclusive with `window_size`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if neither or both of {`window_size`, `window_size_func`} are
+ passed.
+ """
+ if (window_size is not None and window_size_func or
+ not (window_size is not None or window_size_func)):
+ raise ValueError("Must pass either window_size or window_size_func.")
+
+ if window_size is not None:
+
+ def constant_window_func(unused_key):
+ return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
+
+ window_size_func = constant_window_func
+
+ assert window_size_func is not None
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _GroupByWindowDataset(dataset, key_func, reduce_func,
+ window_size_func)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.bucket_by_sequence_length")
+def bucket_by_sequence_length(element_length_func,
+ bucket_boundaries,
+ bucket_batch_sizes,
+ padded_shapes=None,
+ padding_values=None,
+ pad_to_bucket_boundary=False,
+ no_padding=False):
+ """A transformation that buckets elements in a `Dataset` by length.
+
+ Elements of the `Dataset` are grouped together by length and then are padded
+ and batched.
+
+ This is useful for sequence tasks in which the elements have variable length.
+  Grouping together elements that have similar lengths reduces the total
+  fraction of padding in a batch, which increases training step efficiency.
+
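+  For example, a minimal sketch that buckets variable-length vectors;
+  `elements` is a hypothetical generator that yields 1-D `tf.int64` vectors:
+
+  ```python
+  dataset = tf.data.Dataset.from_generator(
+      elements, tf.int64, output_shapes=tf.TensorShape([None]))
+  dataset = dataset.apply(
+      tf.data.experimental.bucket_by_sequence_length(
+          element_length_func=lambda elem: tf.shape(elem)[0],
+          bucket_boundaries=[10, 20, 30],
+          bucket_batch_sizes=[16, 16, 16, 16]))
+  ```
+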
+ Args:
+ element_length_func: function from element in `Dataset` to `tf.int32`,
+ determines the length of the element, which will determine the bucket it
+ goes into.
+ bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
+ bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
+ `len(bucket_boundaries) + 1`.
+ padded_shapes: Nested structure of `tf.TensorShape` to pass to
+ `tf.data.Dataset.padded_batch`. If not provided, will use
+ `dataset.output_shapes`, which will result in variable length dimensions
+ being padded out to the maximum length in each batch.
+ padding_values: Values to pad with, passed to
+ `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
+ pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
+ size to maximum length in batch. If `True`, will pad dimensions with
+ unknown size to bucket boundary minus 1 (i.e., the maximum length in each
+ bucket), and caller must ensure that the source `Dataset` does not contain
+ any elements with length longer than `max(bucket_boundaries)`.
+    no_padding: `bool`, indicates whether to pad the batch features (features
+      need to be either of type `tf.SparseTensor` or of the same shape).
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
+ """
+ with ops.name_scope("bucket_by_seq_length"):
+ if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
+ raise ValueError(
+ "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
+
+ batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
+
+ def element_to_bucket_id(*args):
+ """Return int64 id of the length bucket for this element."""
+ seq_length = element_length_func(*args)
+
+ boundaries = list(bucket_boundaries)
+ buckets_min = [np.iinfo(np.int32).min] + boundaries
+ buckets_max = boundaries + [np.iinfo(np.int32).max]
+ conditions_c = math_ops.logical_and(
+ math_ops.less_equal(buckets_min, seq_length),
+ math_ops.less(seq_length, buckets_max))
+ bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
+
+ return bucket_id
+
+ def window_size_fn(bucket_id):
+ # The window size is set to the batch size for this bucket
+ window_size = batch_sizes[bucket_id]
+ return window_size
+
+ def make_padded_shapes(shapes, none_filler=None):
+ padded = []
+ for shape in nest.flatten(shapes):
+ shape = tensor_shape.TensorShape(shape)
+ shape = [
+ none_filler if d.value is None else d
+ for d in shape
+ ]
+ padded.append(shape)
+ return nest.pack_sequence_as(shapes, padded)
+
+ def batching_fn(bucket_id, grouped_dataset):
+ """Batch elements in dataset."""
+ batch_size = window_size_fn(bucket_id)
+ if no_padding:
+ return grouped_dataset.batch(batch_size)
+ none_filler = None
+ if pad_to_bucket_boundary:
+ err_msg = ("When pad_to_bucket_boundary=True, elements must have "
+ "length < max(bucket_boundaries).")
+ check = check_ops.assert_less(
+ bucket_id,
+ constant_op.constant(len(bucket_batch_sizes) - 1,
+ dtype=dtypes.int64),
+ message=err_msg)
+ with ops.control_dependencies([check]):
+ boundaries = constant_op.constant(bucket_boundaries,
+ dtype=dtypes.int64)
+ bucket_boundary = boundaries[bucket_id]
+ none_filler = bucket_boundary - 1
+ shapes = make_padded_shapes(
+ padded_shapes or grouped_dataset.output_shapes,
+ none_filler=none_filler)
+ return grouped_dataset.padded_batch(batch_size, shapes, padding_values)
+
+ def _apply_fn(dataset):
+ return dataset.apply(
+ group_by_window(element_to_bucket_id, batching_fn,
+ window_size_func=window_size_fn))
+
+ return _apply_fn
+
+
+def _map_x_dataset(map_func):
+ """A transformation that maps `map_func` across its input.
+
+ This transformation is similar to `tf.data.Dataset.map`, but in addition to
+ supporting dense and sparse tensor inputs, it also supports dataset inputs.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors and/or datasets
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to another nested structure of tensors and/or
+ datasets.
+
+ Returns:
+ Dataset: A `Dataset`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _MapXDataset(dataset, map_func)
+
+ return _apply_fn
+
+
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that groups its input and performs a reduction."""
+
+ def __init__(self, input_dataset, key_func, reducer):
+ """See `group_by_reducer()` for details."""
+ super(_GroupByReducerDataset, self).__init__(input_dataset)
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_init_func(reducer.init_func)
+ self._make_reduce_func(reducer.reduce_func, input_dataset)
+ self._make_finalize_func(reducer.finalize_func)
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func, "tf.data.experimental.group_by_reducer()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 tensor. "
+ "Got type=%s and shape=%s"
+ % (wrapped_func.output_types, wrapped_func.output_shapes))
+ self._key_func = wrapped_func.function
+
+ def _make_init_func(self, init_func):
+ """Make wrapping Defun for init_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ init_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=ops.Tensor,
+ input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ self._init_func = wrapped_func.function
+ self._state_classes = wrapped_func.output_classes
+ self._state_shapes = wrapped_func.output_shapes
+ self._state_types = wrapped_func.output_types
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(wrapped_func.output_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ for new_state_type, state_type in zip(
+ nest.flatten(wrapped_func.output_types),
+ nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._reduce_func = wrapped_func.function
+ self._reduce_func.add_to_graph(ops.get_default_graph())
+
+ def _make_finalize_func(self, finalize_func):
+ """Make wrapping Defun for finalize_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ finalize_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=self._state_classes,
+ input_shapes=self._state_shapes,
+ input_types=self._state_types)
+ self._finalize_func = wrapped_func.function
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_reducer_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._init_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._finalize_func.captured_inputs,
+ key_func=self._key_func,
+ init_func=self._init_func,
+ reduce_func=self._reduce_func,
+ finalize_func=self._finalize_func,
+ **dataset_ops.flat_structure(self))
+
+
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that groups its input and performs a windowed reduction."""
+
+ def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
+ """See `group_by_window()` for details."""
+ super(_GroupByWindowDataset, self).__init__(input_dataset)
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_reduce_func(reduce_func, input_dataset)
+ self._make_window_size_func(window_size_func)
+
+ def _make_window_size_func(self, window_size_func):
+ """Make wrapping Defun for window_size_func."""
+ def window_size_func_wrapper(key):
+ return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ window_size_func_wrapper,
+ "tf.data.experimental.group_by_window()",
+ input_classes=ops.Tensor,
+ input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`window_size_func` must return a single tf.int64 scalar tensor.")
+ self._window_size_func = wrapped_func.function
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+ def key_func_wrapper(*args):
+ return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func_wrapper, "tf.data.experimental.group_by_window()",
+ input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 scalar tensor.")
+ self._key_func = wrapped_func.function
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+ nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func,
+ "tf.data.experimental.reduce_by_window()",
+ input_classes=(ops.Tensor, nested_dataset),
+ input_shapes=(tensor_shape.scalar(), nested_dataset),
+ input_types=(dtypes.int64, nested_dataset),
+ experimental_nested_dataset_support=True)
+ if not isinstance(
+ wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
+ raise TypeError("`reduce_func` must return a `Dataset` object.")
+ self._output_classes = wrapped_func.output_classes.output_classes
+ self._output_types = wrapped_func.output_types.output_types
+ self._output_shapes = wrapped_func.output_shapes.output_shapes
+ self._reduce_func = wrapped_func.function
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._window_size_func.captured_inputs,
+ key_func=self._key_func,
+ reduce_func=self._reduce_func,
+ window_size_func=self._window_size_func,
+ **dataset_ops.flat_structure(self))
+
+
+@tf_export("data.experimental.Reducer")
+class Reducer(object):
+ """A reducer is used for reducing a set of elements.
+
+  A reducer is represented as a tuple of three functions:
+ 1) initialization function: key => initial state
+ 2) reduce function: (old state, input) => new state
+ 3) finalization function: state => result
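+
+  For example, a minimal sketch of a reducer that counts the elements in each
+  group (assuming the usual `np` import), which can then be passed to
+  `tf.data.experimental.group_by_reducer`:
+
+  ```python
+  count_reducer = tf.data.experimental.Reducer(
+      init_func=lambda _: np.int64(0),
+      reduce_func=lambda state, _: state + 1,
+      finalize_func=lambda state: state)
+  ```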
+ """
+
+ def __init__(self, init_func, reduce_func, finalize_func):
+ self._init_func = init_func
+ self._reduce_func = reduce_func
+ self._finalize_func = finalize_func
+
+ @property
+ def init_func(self):
+ return self._init_func
+
+ @property
+ def reduce_func(self):
+ return self._reduce_func
+
+ @property
+ def finalize_func(self):
+ return self._finalize_func
+
+
+class _MapXDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that maps a function over elements in its input."""
+
+ def __init__(self, input_dataset, map_func):
+ """See `map_x_dataset()` for details."""
+ super(_MapXDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ map_func,
+ "tf.data.experimental.map_x_dataset()",
+ input_dataset,
+ experimental_nested_dataset_support=True)
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+ self._map_func = wrapped_func.function
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.map_dataset(
+ input_t,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py
index 9c06474a2f..9c06474a2f 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py
diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py
new file mode 100644
index 0000000000..a3c094859e
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/interleave_ops.py
@@ -0,0 +1,262 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Non-deterministic dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.ops import gen_stateless_random_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.parallel_interleave")
+def parallel_interleave(map_func,
+ cycle_length,
+ block_length=1,
+ sloppy=False,
+ buffer_output_elements=None,
+ prefetch_input_elements=None):
+ """A parallel version of the `Dataset.interleave()` transformation.
+
+ `parallel_interleave()` maps `map_func` across its input to produce nested
+ datasets, and outputs their elements interleaved. Unlike
+ `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested
+ datasets in parallel, which increases the throughput, especially in the
+ presence of stragglers. Furthermore, the `sloppy` argument can be used to
+ improve performance, by relaxing the requirement that the outputs are produced
+ in a deterministic order, and allowing the implementation to skip over nested
+ datasets whose elements are not readily available when requested.
+
+ Example usage:
+
+ ```python
+ # Preprocess 4 files concurrently.
+ filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
+ dataset = filenames.apply(
+ tf.data.experimental.parallel_interleave(
+ lambda filename: tf.data.TFRecordDataset(filename),
+ cycle_length=4))
+ ```
+
+ WARNING: If `sloppy` is `True`, the order of produced elements is not
+ deterministic.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors to a `Dataset`.
+ cycle_length: The number of input `Dataset`s to interleave from in parallel.
+ block_length: The number of consecutive elements to pull from an input
+ `Dataset` before advancing to the next input `Dataset`.
+ sloppy: If false, elements are produced in deterministic order. Otherwise,
+ the implementation is allowed, for the sake of expediency, to produce
+ elements in a non-deterministic order.
+ buffer_output_elements: The number of elements each iterator being
+ interleaved should buffer (similar to the `.prefetch()` transformation for
+ each interleaved iterator).
+ prefetch_input_elements: The number of input elements to transform to
+ iterators before they are needed for interleaving.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return readers.ParallelInterleaveDataset(
+ dataset, map_func, cycle_length, block_length, sloppy,
+ buffer_output_elements, prefetch_input_elements)
+
+ return _apply_fn
+
+
+class _DirectedInterleaveDataset(dataset_ops.Dataset):
+ """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
+
+ def __init__(self, selector_input, data_inputs):
+ self._selector_input = selector_input
+ self._data_inputs = list(data_inputs)
+
+ for data_input in data_inputs[1:]:
+ if (data_input.output_types != data_inputs[0].output_types or
+ data_input.output_classes != data_inputs[0].output_classes):
+ raise TypeError("All datasets must have the same type and class.")
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return (
+ gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(), [
+ data_input._as_variant_tensor()
+ for data_input in self._data_inputs
+ ], **dataset_ops.flat_structure(self)))
+ # pylint: enable=protected-access
+
+ def _inputs(self):
+ return [self._selector_input] + self._data_inputs
+
+ @property
+ def output_classes(self):
+ return self._data_inputs[0].output_classes
+
+ @property
+ def output_shapes(self):
+ ret = self._data_inputs[0].output_shapes
+ for data_input in self._data_inputs[1:]:
+ ret = nest.pack_sequence_as(ret, [
+ ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
+ nest.flatten(ret), nest.flatten(data_input.output_shapes))
+ ])
+ return ret
+
+ @property
+ def output_types(self):
+ return self._data_inputs[0].output_types
+
+
+@tf_export("data.experimental.sample_from_datasets")
+def sample_from_datasets(datasets, weights=None, seed=None):
+ """Samples elements at random from the datasets in `datasets`.
+
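+  For example, a minimal sketch that draws roughly 75% of its elements from
+  the first dataset and 25% from the second:
+
+  ```python
+  datasets = [tf.data.Dataset.from_tensors("a").repeat(),
+              tf.data.Dataset.from_tensors("b").repeat()]
+  sampled = tf.data.experimental.sample_from_datasets(
+      datasets, weights=[0.75, 0.25])
+  ```
+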
+ Args:
+ datasets: A list of `tf.data.Dataset` objects with compatible structure.
+ weights: (Optional.) A list of `len(datasets)` floating-point values where
+ `weights[i]` represents the probability with which an element should be
+ sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
+ element is such a list. Defaults to a uniform distribution across
+ `datasets`.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ `tf.set_random_seed` for behavior.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` at random, according to
+ `weights` if provided, otherwise with uniform probability.
+
+ Raises:
+ TypeError: If the `datasets` or `weights` arguments have the wrong type.
+ ValueError: If the `weights` argument is specified and does not match the
+ length of the `datasets` element.
+ """
+ num_datasets = len(datasets)
+ if not isinstance(weights, dataset_ops.Dataset):
+ if weights is None:
+ # Select inputs with uniform probability.
+ logits = [[1.0] * num_datasets]
+
+ else:
+ # Use the given `weights` as the probability of choosing the respective
+ # input.
+ weights = ops.convert_to_tensor(weights, name="weights")
+ if weights.dtype not in (dtypes.float32, dtypes.float64):
+ raise TypeError("`weights` must be convertible to a tensor of "
+ "`tf.float32` or `tf.float64` elements.")
+ if not weights.shape.is_compatible_with([num_datasets]):
+ raise ValueError(
+ "`weights` must be a vector of length `len(datasets)`.")
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed
+ # to weights.
+ logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+
+ # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
+ # is a `Dataset`, it is possible that evaluating it has a side effect the
+ # user depends on.
+ if len(datasets) == 1:
+ return datasets[0]
+
+ def select_dataset_constant_logits(seed):
+ return array_ops.squeeze(
+ gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
+ axis=[0, 1])
+
+ selector_input = dataset_ops.MapDataset(
+ random_ops.RandomDataset(seed).batch(2),
+ select_dataset_constant_logits,
+ use_inter_op_parallelism=False)
+
+ else:
+ # Use each element of the given `weights` dataset as the probability of
+ # choosing the respective input.
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed to
+ # weights.
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+
+ def select_dataset_varying_logits(logits, seed):
+ return array_ops.squeeze(
+ gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
+ axis=[0, 1])
+
+ logits_and_seeds = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2)))
+ selector_input = dataset_ops.MapDataset(
+ logits_and_seeds,
+ select_dataset_varying_logits,
+ use_inter_op_parallelism=False)
+
+ return _DirectedInterleaveDataset(selector_input, datasets)
+
+
+@tf_export("data.experimental.choose_from_datasets")
+def choose_from_datasets(datasets, choice_dataset):
+ """Creates a dataset that deterministically chooses elements from `datasets`.
+
+ For example, given the following datasets:
+
+ ```python
+ datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
+ tf.data.Dataset.from_tensors("bar").repeat(),
+ tf.data.Dataset.from_tensors("baz").repeat()]
+
+ # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
+ choice_dataset = tf.data.Dataset.range(3).repeat(3)
+
+ result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
+ ```
+
+ The elements of `result` will be:
+
+ ```
+ "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
+ ```
+
+ Args:
+ datasets: A list of `tf.data.Dataset` objects with compatible structure.
+ choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
+ `0` and `len(datasets) - 1`.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` according to the values
+ of `choice_dataset`.
+
+ Raises:
+ TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
+ type.
+ """
+ if not (choice_dataset.output_types == dtypes.int64
+ and choice_dataset.output_shapes.is_compatible_with(
+ tensor_shape.scalar())
+ and choice_dataset.output_classes == ops.Tensor):
+ raise TypeError("`choice_dataset` must be a dataset of scalar "
+ "`tf.int64` tensors.")
+ return _DirectedInterleaveDataset(choice_dataset, datasets)
diff --git a/tensorflow/python/data/experimental/ops/iterator_ops.py b/tensorflow/python/data/experimental/ops/iterator_ops.py
new file mode 100644
index 0000000000..72d7d58f06
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/iterator_ops.py
@@ -0,0 +1,268 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Iterator ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.make_saveable_from_iterator")
+def make_saveable_from_iterator(iterator):
+ """Returns a SaveableObject for saving/restore iterator state using Saver.
+
+ Args:
+ iterator: Iterator.
+
+ For example:
+
+ ```python
+ with tf.Graph().as_default():
+ ds = tf.data.Dataset.range(10)
+ iterator = ds.make_initializable_iterator()
+ # Build the iterator SaveableObject.
+ saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
+ # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
+ # it can be automatically saved using Saver.
+ tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
+ saver = tf.train.Saver()
+
+ while continue_training:
+ ... Perform training ...
+ if should_save_checkpoint:
+ saver.save()
+ ```
+
+ Note: When restoring the iterator, the existing iterator state is completely
+ discarded. This means that any changes you may have made to the Dataset
+ graph will be discarded as well! This includes the new Dataset graph
+ that you may have built during validation. So, while running validation,
+ make sure to run the initializer for the validation input pipeline after
+ restoring the checkpoint.
+
+ Note: Not all iterators support checkpointing yet. Attempting to save the
+ state of an unsupported iterator will throw an error.
+ """
+ return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
+
+
+class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
+ """SaveableObject for saving/restoring iterator state."""
+
+ def __init__(self, iterator_resource):
+ serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
+ specs = [
+ saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
+ iterator_resource.name + "-state")
+ ]
+ super(_Saveable, self).__init__(iterator_resource, specs,
+ iterator_resource.name)
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ with ops.colocate_with(self.op):
+ return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+@tf_export("data.experimental.CheckpointInputPipelineHook")
+class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+ """Checkpoints input pipeline state every N steps or seconds.
+
+ This hook saves the state of the iterators in the `Graph` so that when
+ training is resumed the input pipeline continues from where it left off.
+ This could potentially avoid overfitting in certain pipelines where the
+ number of training steps per eval is small compared to the dataset
+ size, or if the training pipeline is preempted.
+
+ Differences from `CheckpointSaverHook`:
+ 1. Saves only the input pipelines in the "iterators" collection and not the
+ global variables or other saveable objects.
+ 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
+
+ Example of checkpointing the training pipeline:
+
+ ```python
+ est = tf.estimator.Estimator(model_fn)
+ while True:
+ est.train(
+ train_input_fn,
+ hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
+ steps=train_steps_per_eval)
+ # Note: We do not pass the hook here.
+ metrics = est.evaluate(eval_input_fn)
+ if should_stop_the_training(metrics):
+ break
+ ```
+
+ This hook should be used if the input pipeline state needs to be saved
+ separately from the model checkpoint. Doing so may be useful for a few reasons:
+ 1. The input pipeline checkpoint may be large, if there are large shuffle
+ or prefetch buffers for instance, and may bloat the checkpoint size.
+ 2. If the input pipeline is shared between training and validation, restoring
+ the checkpoint during validation may override the validation input
+ pipeline.
+
+ For saving the input pipeline checkpoint alongside the model weights use
+ `tf.data.experimental.make_saveable_from_iterator` directly to create a
+ `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
+ that you will need to be careful not to restore the training iterator during
+ eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
+ collection when building the eval graph.
+ """
+
+ def __init__(self, estimator):
+ """Initializes a `CheckpointInputPipelineHook`.
+
+ Args:
+ estimator: Estimator.
+
+ Raises:
+ ValueError: One of `save_steps` or `save_secs` should be set.
+ ValueError: At most one of saver or scaffold should be set.
+ """
+ # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
+ # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
+ # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
+ # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
+ # to be different to avoid conflicts with the model checkpoint.
+
+ # pylint: disable=protected-access
+ checkpoint_prefix = "input"
+ if estimator._config.num_worker_replicas > 1:
+ # Distributed setting.
+ suffix = "_{}_{}".format(estimator._config.task_type,
+ estimator._config.task_id)
+ checkpoint_prefix += suffix
+ # pylint: enable=protected-access
+
+ # We use a composition paradigm instead of inheriting from
+ # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
+ # to check whether a `CheckpointSaverHook` is already present in the list
+ # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
+ # would thwart this behavior. This hook checkpoints *only the iterators*
+ # and not the graph variables.
+ self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
+ estimator.model_dir,
+ save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
+ save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
+ checkpoint_basename=checkpoint_prefix + ".ckpt")
+
+ # Name for the protocol buffer file that will contain the list of most
+ # recent checkpoints stored as a `CheckpointState` protocol buffer.
+ # This file, kept in the same directory as the checkpoint files, is
+ # automatically managed by the `Saver` to keep track of recent checkpoints.
+ # The default name used by the `Saver` for this file is "checkpoint". Here
+ # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
+ # `checkpoint_dir` is the same as the model checkpoint directory, there are
+ # no conflicts during restore.
+ self._latest_filename = "checkpoint_" + checkpoint_prefix
+ self._first_run = True
+
+ def begin(self):
+ # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
+ # collection if no `Saver` or `Scaffold` is provided.
+ # pylint: disable=protected-access
+ if (self._checkpoint_saver_hook._saver is None and
+ self._checkpoint_saver_hook._scaffold is None):
+ iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
+ saveables = [_Saveable(i) for i in iterators]
+ self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
+ self._latest_filename)
+ # pylint: enable=protected-access
+ self._checkpoint_saver_hook.begin()
+
+ def _restore_or_save_initial_ckpt(self, session):
+ # Ideally this should be run in after_create_session but is not for the
+ # following reason:
+ # Currently there is no way of enforcing an order of running the
+ # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
+ # is run *after* this hook. That is troublesome because
+ # 1. If a checkpoint exists and this hook restores it, the initializer hook
+ # will override it.
+ # 2. If no checkpoint exists, this hook will try to save an initialized
+ # iterator which will result in an exception.
+ #
+ # As a temporary fix we enter the following implicit contract between this
+ # hook and the _DatasetInitializerHook.
+ # 1. The _DatasetInitializerHook initializes the iterator in the call to
+ # after_create_session.
+ # 2. This hook saves the iterator on the first call to `before_run()`, which
+ # is guaranteed to happen after `after_create_session()` of all hooks
+ # have been run.
+
+ # Check if there is an existing checkpoint. If so, restore from it.
+ # pylint: disable=protected-access
+ latest_checkpoint_path = checkpoint_management.latest_checkpoint(
+ self._checkpoint_saver_hook._checkpoint_dir,
+ latest_filename=self._latest_filename)
+ if latest_checkpoint_path:
+ self._checkpoint_saver_hook._get_saver().restore(session,
+ latest_checkpoint_path)
+ else:
+ # The checkpoint saved here is the state at step "global_step".
+ # Note: We do not save the GraphDef or MetaGraphDef here.
+ global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
+ self._checkpoint_saver_hook._save(session, global_step)
+ self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
+ # pylint: enable=protected-access
+
+ def before_run(self, run_context):
+ if self._first_run:
+ self._restore_or_save_initial_ckpt(run_context.session)
+ self._first_run = False
+ return self._checkpoint_saver_hook.before_run(run_context)
+
+ def after_run(self, run_context, run_values):
+ self._checkpoint_saver_hook.after_run(run_context, run_values)
+
+ def end(self, session):
+ self._checkpoint_saver_hook.end(session)
+
+
+class _CustomSaver(saver_lib.Saver):
+ """`Saver` with a different default `latest_filename`.
+
+ This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
+ the model ckpt saved by the `CheckpointSaverHook`.
+ """
+
+ def __init__(self, var_list, latest_filename):
+ super(_CustomSaver, self).__init__(var_list)
+ self._latest_filename = latest_filename
+
+ def save(self,
+ sess,
+ save_path,
+ global_step=None,
+ latest_filename=None,
+ meta_graph_suffix="meta",
+ write_meta_graph=True,
+ write_state=True,
+ strip_default_attrs=False):
+ return super(_CustomSaver, self).save(
+ sess, save_path, global_step, latest_filename or self._latest_filename,
+ meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
+
+
+tf_export("data.experimental.Optional")(optional_ops.Optional)
+tf_export("data.experimental.get_next_as_optional")(
+ iterator_ops.get_next_as_optional)
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py
index 3d0d0993c9..3d0d0993c9 100644
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ b/tensorflow/python/data/experimental/ops/map_defun.py
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py
index 30348ede36..30348ede36 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/python/data/experimental/ops/optimization.py
diff --git a/tensorflow/python/data/experimental/ops/parsing_ops.py b/tensorflow/python/data/experimental/ops/parsing_ops.py
new file mode 100644
index 0000000000..6615b9022a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/parsing_ops.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental `dataset` API for parsing example."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that parses `example` dataset into a `dict` dataset."""
+
+ def __init__(self, input_dataset, features, num_parallel_calls):
+ super(_ParseExampleDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if not all(types == dtypes.string
+ for types in nest.flatten(input_dataset.output_types)):
+ raise TypeError("Input dataset should be a dataset of vectors of strings")
+ self._num_parallel_calls = num_parallel_calls
+ # pylint: disable=protected-access
+ self._features = parsing_ops._prepend_none_dimension(features)
+ # sparse_keys and dense_keys come back sorted here.
+ (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
+ dense_shapes) = parsing_ops._features_to_raw_params(
+ self._features, [
+ parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
+ parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
+ ])
+ # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
+ (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
+ dense_shape_as_shape) = parsing_ops._process_raw_parameters(
+ None, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
+ # pylint: enable=protected-access
+ self._sparse_keys = sparse_keys
+ self._sparse_types = sparse_types
+ self._dense_keys = dense_keys
+ self._dense_defaults = dense_defaults_vec
+ self._dense_shapes = dense_shapes
+ self._dense_types = dense_types
+ dense_output_shapes = [
+ self._input_dataset.output_shapes.concatenate(shape)
+ for shape in dense_shape_as_shape
+ ]
+ sparse_output_shapes = [
+ self._input_dataset.output_shapes.concatenate([None])
+ for _ in range(len(sparse_keys))
+ ]
+
+ self._output_shapes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ dense_output_shapes + sparse_output_shapes))
+ self._output_types = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ self._dense_types + self._sparse_types))
+ self._output_classes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ [ops.Tensor for _ in range(len(self._dense_defaults))] +
+ [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
+ ]))
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parse_example_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._num_parallel_calls,
+ self._dense_defaults,
+ self._sparse_keys,
+ self._dense_keys,
+ self._sparse_types,
+ self._dense_shapes,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+# TODO(b/111553342): add arguments names and example names as well.
+@tf_export("data.experimental.parse_example_dataset")
+def parse_example_dataset(features, num_parallel_calls=1):
+ """A transformation that parses `Example` protos into a `dict` of tensors.
+
+ Parses serialized `Example` protos from the input dataset. Each element of
+ the input dataset must be a vector of `tf.string`, which is treated as a
+ batch of individual serialized `Example` protos.
+
+ This op parses serialized examples into a dictionary mapping keys to `Tensor`
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more
+ details about feature dictionaries.
+
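+ For example, a minimal sketch (the feature spec, file name, and batch size
+ below are illustrative, not part of this API):
+
+ ```python
+ features = {
+     "age": tf.FixedLenFeature([], tf.int64, default_value=0),
+     "tags": tf.VarLenFeature(tf.string),
+ }
+ dataset = tf.data.TFRecordDataset("examples.tfrecord").batch(32)
+ dataset = dataset.apply(
+     tf.data.experimental.parse_example_dataset(features))
+ # Each element is now a dict mapping "age" to a `Tensor` and "tags" to a
+ # `SparseTensor`.
+ ```
+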
+ Args:
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of parsing processes to call in parallel.
+
+ Returns:
+ A dataset transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if features argument is None.
+ """
+ if features is None:
+ raise ValueError("Missing: features was %s." % features)
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
+ if any([
+ isinstance(feature, parsing_ops.SparseFeature)
+ for _, feature in features.items()
+ ]):
+ # pylint: disable=protected-access
+ # pylint: disable=g-long-lambda
+ out_dataset = out_dataset.map(
+ lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
+ features, x), num_parallel_calls=num_parallel_calls)
+ return out_dataset
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py
new file mode 100644
index 0000000000..48d7136f95
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py
@@ -0,0 +1,531 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import device as framework_device
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def function_buffering_resource(string_arg,
+ target_device,
+ f,
+ buffer_size,
+ output_types,
+ container="",
+ shared_name=None,
+ name=None):
+ """Creates a FunctionBufferingResource.
+
+ A FunctionBufferingResource fills up a buffer by calling a function `f` on
+ `target_device`. `f` should take in only a single string argument as input.
+
+ Args:
+ string_arg: The single string argument to the function.
+ target_device: The device to run `f` on.
+ f: The function to be executed.
+ buffer_size: Size of the buffer to be populated.
+ output_types: The output types generated by the function.
+ container: (Optional) string. Defaults to "".
+ shared_name: (Optional) string.
+ name: (Optional) string to name the op.
+
+ Returns:
+ Handle to a FunctionBufferingResource.
+ """
+ if shared_name is None:
+ shared_name = ""
+ return ged_ops.experimental_function_buffering_resource(
+ string_arg=string_arg,
+ target_device=target_device,
+ shared_name=shared_name,
+ f=f,
+ buffer_size=buffer_size,
+ container=container,
+ name=name,
+ output_types=output_types)
+
+
+def function_buffering_resource_get_next(function_buffer_resource,
+ output_types,
+ name=None):
+ return ged_ops.experimental_function_buffering_resource_get_next(
+ function_buffer_resource=function_buffer_resource,
+ output_types=output_types,
+ name=name)
+
+
+def function_buffering_resource_reset(function_buffer_resource, name=None):
+ return ged_ops.experimental_function_buffering_resource_reset(
+ function_buffer_resource=function_buffer_resource, name=name)
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ device: A fully specified device string where we want to prefetch to
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be
+ shared under the given name across multiple sessions that share the
+ same devices (e.g. when using a remote server).
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ device,
+ buffer_size,
+ shared_name=None):
+ self._input_dataset = input_dataset
+ self._get_next_call_count = 0
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ iterator_device = ged_ops.experimental_iterator_get_device(
+ self._input_iterator._iterator_resource)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ target_device=iterator_device,
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)))
+
+ if not self._one_shot:
+ reset_op = function_buffering_resource_reset(self._buffering_resource)
+ with ops.control_dependencies([reset_op]):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
+
+ def get_next(self, name=None):
+ """See `tf.data.Iterator.get_next`."""
+ self._get_next_call_count += 1
+ if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+ warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
+ self._buffering_resource,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
+
+ ret = sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self.output_types, flat_ret),
+ self.output_types, self.output_shapes, self.output_classes)
+
+ for tensor, shape in zip(
+ nest.flatten(ret), nest.flatten(self.output_shapes)):
+ if isinstance(tensor, ops.Tensor):
+ tensor.set_shape(shape)
+
+ return ret
+
+ @property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ device: A fully specified device string where we want to prefetch to
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be
+ shared under the given name across multiple sessions that share the
+ same devices (e.g. when using a remote server).
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ device,
+ buffer_size):
+ with ops.device("/device:CPU:0"):
+ super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
+ input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
+ self._resource)
+
+ self._device = device
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self.output_types, self.output_shapes, self.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ _prefetch_fn.add_to_graph(None)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ output_types=self._flat_output_types,
+ target_device=ged_ops.experimental_iterator_get_device(
+ self._resource),
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=iterator_ops._generate_shared_name(
+ "function_buffer_resource"))
+
+ def _next_internal(self):
+ """Returns a nested structure of `tf.Tensor`s containing the next element.
+ """
+ # This runs in sync mode as iterators use an error status to communicate
+ # that there is no more data to iterate over.
+ # TODO(b/77291417): Fix
+ with context.execution_mode(context.SYNC):
+ with ops.device(self._device):
+ ret = ged_ops.experimental_function_buffering_resource_get_next(
+ function_buffer_resource=self._buffering_resource,
+ output_types=self._flat_output_types)
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` whose iterator prefetches elements to another device."""
+
+ def __init__(self, input_dataset, device, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._device = device
+ self._buffer_size = buffer_size if buffer_size is not None else 1
+
+ # The static analysis cannot tell that the eager iterator's superclass has
+ # a `next()` method.
+ # pylint: disable=non-iterator-returned
+ def __iter__(self):
+ """Creates an `Iterator` for enumerating the elements of this dataset.
+
+ The returned iterator implements the Python iterator protocol and therefore
+ can only be used in eager mode.
+
+ Returns:
+ An `Iterator` over the elements of this dataset.
+
+ Raises:
+ RuntimeError: If eager execution is not enabled.
+ """
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ raise RuntimeError("dataset.__iter__() is only supported when eager "
+ "execution is enabled.")
+ # pylint: enable=non-iterator-returned
+
+ def make_one_shot_iterator(self):
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
+ device=self._device,
+ buffer_size=self._buffer_size)
+
+ def make_initializable_iterator(self, shared_name=None):
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ device=self._device,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
+
+ def _as_variant_tensor(self):
+ # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+ # transformation methods is called).
+ # TODO(mrry): Investigate support for chaining further transformations after
+ # the prefetch, including GPU support.
+ raise NotImplementedError("`prefetch_to_device()` must be the last "
+ "transformation in a dataset pipeline.")
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+@tf_export("data.experimental.prefetch_to_device")
+def prefetch_to_device(device, buffer_size=None):
+ """A transformation that prefetches dataset values to the given `device`.
+
+ NOTE: Although the transformation creates a `tf.data.Dataset`, the
+ transformation must be the final `Dataset` in the input pipeline.
+
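+ For example, a minimal sketch (the device string and buffer size are
+ illustrative):
+
+ ```python
+ dataset = tf.data.Dataset.range(100)
+ dataset = dataset.apply(
+     tf.data.experimental.prefetch_to_device("/device:GPU:0", buffer_size=2))
+ # No further transformations may be applied; consume the dataset directly.
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ ```
+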
+ Args:
+ device: A string. The name of a device to which elements will be prefetched.
+ buffer_size: (Optional.) The number of elements to buffer on `device`.
+ Defaults to an automatically chosen value.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return _PrefetchToDeviceDataset(dataset, device, buffer_size)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.copy_to_device")
+def copy_to_device(target_device, source_device="/cpu:0"):
+ """A transformation that copies dataset elements to the given `target_device`.
+
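+ For example, a minimal sketch (the device string is illustrative). Note that
+ when the target device is a GPU, an initializable iterator must be used:
+
+ ```python
+ dataset = tf.data.Dataset.range(100)
+ dataset = dataset.apply(
+     tf.data.experimental.copy_to_device("/device:GPU:0"))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ ```
+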
+ Args:
+ target_device: The name of a device to which elements will be copied.
+ source_device: The original device on which `input_dataset` will be placed.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _CopyToDeviceDataset(
+ dataset, target_device=target_device, source_device=source_device)
+
+ return _apply_fn
+
+
+# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
+# all inputs to the Op are in host memory, thereby avoiding some unnecessary
+# Sends and Recvs.
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that copies elements to another device."""
+
+ def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
+ """Constructs a _CopyToDeviceDataset.
+
+ Args:
+ input_dataset: `Dataset` to be copied
+ target_device: The name of the device to which elements will be copied.
+ source_device: The device on which `input_dataset` will be placed.
+ """
+ super(_CopyToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._target_device = target_device
+ spec = framework_device.DeviceSpec().from_string(self._target_device)
+ self._is_gpu_target = (spec.device_type == "GPU")
+ self._source_device_string = source_device
+ self._source_device = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._input_dataset.output_shapes,
+ self._input_dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes))
+
+ @function.Defun()
+ def _init_func():
+ """Creates an iterator for the input dataset.
+
+ Returns:
+ A `string` tensor that encapsulates the iterator created.
+ """
+ # pylint: disable=protected-access
+ ds_variant = self._input_dataset._as_variant_tensor()
+ resource = gen_dataset_ops.anonymous_iterator(
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies(
+ [gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return gen_dataset_ops.iterator_to_string_handle(resource)
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ """Calls get_next for created iterator.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ The elements generated from `input_dataset`
+ """
+ with ops.device(self._source_device_string):
+ iterator = iterator_ops.Iterator.from_string_handle(
+ string_handle, self.output_types, self.output_shapes,
+ self.output_classes)
+ ret = iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(string_handle):
+ """Destroys the iterator resource created.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ Tensor constant 0
+ """
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies([
+ resource_variable_ops.destroy_resource_op(
+ iterator_resource, ignore_lookup_error=True)]):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ g = ops.get_default_graph()
+ _remote_init_func.add_to_graph(g)
+ _remote_next_func.add_to_graph(g)
+ _remote_finalize_func.add_to_graph(g)
+ # pylint: enable=protected-access
+
+ # The one_shot_iterator implementation needs a zero-argument _make_dataset
+ # function that captures all the inputs required to create the dataset. Since
+ # some of those inputs to the GeneratorDataset are strings, which can't be
+ # placed on a GPU, this fails for the GPU case. Therefore, one-shot iterators
+ # are disabled when the target device is a GPU.
+ def make_one_shot_iterator(self):
+ if self._is_gpu_target:
+ raise ValueError("Cannot create a one shot iterator when using "
+ "`tf.data.experimental.copy_to_device()` on GPU. Please "
+ "use `Dataset.make_initializable_iterator()` instead.")
+ else:
+ return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py
new file mode 100644
index 0000000000..e3a2aeab31
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/random_ops.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Datasets for random number generators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import random_seed
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.RandomDataset")
+class RandomDataset(dataset_ops.DatasetSource):
+ """A `Dataset` of pseudorandom values."""
+
+ def __init__(self, seed=None):
+ """A `Dataset` of pseudorandom values."""
+ super(RandomDataset, self).__init__()
+ self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.random_dataset(
+ seed=self._seed,
+ seed2=self._seed2,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.int64
diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py
new file mode 100644
index 0000000000..3b2d094514
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/readers.py
@@ -0,0 +1,904 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for reader Datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import csv
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import parsing_ops
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import convert
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.util.tf_export import tf_export
+
+_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.int64, dtypes.string)
+
+
+def _is_valid_int32(str_val):
+ try:
+ # Checks equality to prevent int32 overflow
+ return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
+ str_val)
+ except (ValueError, OverflowError):
+ return False
+
+
+def _is_valid_int64(str_val):
+ try:
+ dtypes.int64.as_numpy_dtype(str_val)
+ return True
+ except (ValueError, OverflowError):
+ return False
+
+
+def _is_valid_float(str_val, float_dtype):
+ try:
+ return float_dtype.as_numpy_dtype(str_val) < np.inf
+ except ValueError:
+ return False
+
+
+def _infer_type(str_val, na_value, prev_type):
+ """Given a string, infers its tensor type.
+
+ Infers the type of a value by picking the least 'permissive' type possible,
+ while still allowing the previous type inference for this column to be valid.
+
+ Args:
+ str_val: String value to infer the type of.
+ na_value: Additional string to recognize as a NA/NaN CSV value.
+ prev_type: Type previously inferred based on values of this column that
+ we've seen up till now.
+ Returns:
+ Inferred dtype.
+ """
+ if str_val in ("", na_value):
+ # If the field is null, it gives no extra information about its type
+ return prev_type
+
+ type_list = [
+ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
+ ] # list of types to try, ordered from least permissive to most
+
+ type_functions = [
+ _is_valid_int32,
+ _is_valid_int64,
+ lambda str_val: _is_valid_float(str_val, dtypes.float32),
+ lambda str_val: _is_valid_float(str_val, dtypes.float64),
+ lambda str_val: True,
+ ] # Corresponding list of validation functions
+
+ for i in range(len(type_list)):
+ validation_fn = type_functions[i]
+ if validation_fn(str_val) and (prev_type is None or
+ prev_type in type_list[:i + 1]):
+ return type_list[i]
+
+
+def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
+ """Generator that yields rows of CSV file(s) in order."""
+ for fn in filenames:
+ with file_io.FileIO(fn, "r") as f:
+ rdr = csv.reader(
+ f,
+ delimiter=field_delim,
+ quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
+ if header:
+ next(rdr) # Skip header lines
+
+ for csv_row in rdr:
+ if len(csv_row) != num_cols:
+ raise ValueError(
+ "Problem inferring types: CSV row has different number of fields "
+ "than expected.")
+ yield csv_row
+
+
+def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
+ na_value, header, num_rows_for_inference,
+ select_columns):
+ """Infers column types from the first N valid CSV records of files."""
+ if select_columns is None:
+ select_columns = range(num_cols)
+ inferred_types = [None] * len(select_columns)
+
+ for i, csv_row in enumerate(
+ _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
+ if num_rows_for_inference is not None and i >= num_rows_for_inference:
+ break
+
+ for j, col_index in enumerate(select_columns):
+ inferred_types[j] = _infer_type(csv_row[col_index], na_value,
+ inferred_types[j])
+
+ # Replace None's with a default type
+ inferred_types = [t or dtypes.string for t in inferred_types]
+ # Default to 0 or '' for null values
+ return [
+ constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
+ for t in inferred_types
+ ]
+
+
+def _infer_column_names(filenames, field_delim, use_quote_delim):
+ """Infers column names from first rows of files."""
+ csv_kwargs = {
+ "delimiter": field_delim,
+ "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
+ }
+ with file_io.FileIO(filenames[0], "r") as f:
+ try:
+ column_names = next(csv.reader(f, **csv_kwargs))
+ except StopIteration:
+ raise ValueError(("Received StopIteration when reading the header line "
+ "of %s. Empty file?") % filenames[0])
+
+ for name in filenames[1:]:
+ with file_io.FileIO(name, "r") as f:
+ try:
+ if next(csv.reader(f, **csv_kwargs)) != column_names:
+ raise ValueError(
+ "Files have different column names in the header row.")
+ except StopIteration:
+ raise ValueError(("Received StopIteration when reading the header line "
+ "of %s. Empty file?") % filenames[0])
+ return column_names
+
+
+def _get_sorted_col_indices(select_columns, column_names):
+ """Transforms select_columns argument into sorted column indices."""
+ names_to_indices = {n: i for i, n in enumerate(column_names)}
+ num_cols = len(column_names)
+ for i, v in enumerate(select_columns):
+ if isinstance(v, int):
+ if v < 0 or v >= num_cols:
+ raise ValueError(
+ "Column index %d specified in select_columns out of valid range." %
+ v)
+ continue
+ if v not in names_to_indices:
+ raise ValueError(
+ "Value '%s' specified in select_columns not a valid column index or "
+ "name." % v)
+ select_columns[i] = names_to_indices[v]
+
+ # Sort and ensure there are no duplicates
+ result = sorted(set(select_columns))
+ if len(result) != len(select_columns):
+ raise ValueError("select_columns contains duplicate columns")
+ return result
+
+
+def _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
+ """Optionally shuffle and repeat dataset, as requested."""
+ if num_epochs != 1 and shuffle:
+ # Use shuffle_and_repeat for perf
+ return dataset.apply(
+ shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
+ shuffle_seed))
+ elif shuffle:
+ return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
+ elif num_epochs != 1:
+ return dataset.repeat(num_epochs)
+ return dataset
+
+
+def make_tf_record_dataset(file_pattern,
+ batch_size,
+ parser_fn=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=None,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=None,
+ num_parallel_parser_calls=None,
+ drop_final_batch=False):
+ """Reads and optionally parses TFRecord files into a dataset.
+
+ Provides common functionality such as batching, optional parsing, shuffling,
+ and performant defaults.
+
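+ For example, a minimal sketch (the file pattern below is illustrative; with
+ no `parser_fn`, each element is a batch of serialized record strings):
+
+ ```python
+ dataset = make_tf_record_dataset(
+     "/path/to/data-*.tfrecord", batch_size=32, num_epochs=1)
+ ```
+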
+ Args:
+ file_pattern: List of files or patterns of TFRecord file paths.
+ See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ parser_fn: (Optional.) A function accepting string input to parse
+ and process the record contents. This function must map records
+ to components of a fixed shape, so they may be batched. By
+ default, uses the record contents unmodified.
+ num_epochs: (Optional.) An int specifying the number of times this
+ dataset is repeated. If None (the default), cycles through the
+ dataset forever.
+ shuffle: (Optional.) A bool that indicates whether the input
+ should be shuffled. Defaults to `True`.
+ shuffle_buffer_size: (Optional.) Buffer size to use for
+ shuffling. A large buffer size ensures better shuffling, but
+ increases memory usage and startup time.
+ shuffle_seed: (Optional.) Randomization seed to use for shuffling.
+ prefetch_buffer_size: (Optional.) An int specifying the number of
+ feature batches to prefetch for performance improvement.
+ Defaults to auto-tune. Set to 0 to disable prefetching.
+ num_parallel_reads: (Optional.) Number of threads used to read
+ records from files. By default or if set to a value >1, the
+ results will be interleaved.
+ num_parallel_parser_calls: (Optional.) Number of records to parse in
+ parallel. Defaults to an automatic selection.
+ drop_final_batch: (Optional.) Whether the last batch should be
+ dropped in case its size is smaller than `batch_size`; the
+ default behavior is not to drop the smaller batch.
+
+ Returns:
+ A dataset, where each element matches the output of `parser_fn`
+ except it will have an additional leading `batch_size` dimension,
+ or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
+ unspecified.
+ """
+ files = dataset_ops.Dataset.list_files(
+ file_pattern, shuffle=shuffle, seed=shuffle_seed)
+
+ if num_parallel_reads is None:
+ # Note: We considered auto-tuning this value, but there is a concern
+ # that this affects the mixing of records from different files, which
+ # could affect training convergence/accuracy, so we are defaulting to
+ # a constant for now.
+ num_parallel_reads = 24
+ dataset = core_readers.TFRecordDataset(
+ files, num_parallel_reads=num_parallel_reads)
+
+ if shuffle_buffer_size is None:
+ # TODO(josh11b): Auto-tune this value when not specified
+ shuffle_buffer_size = 10000
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ drop_final_batch = drop_final_batch or num_epochs is None
+
+ if parser_fn is None:
+ dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
+ else:
+ # TODO(josh11b): if num_parallel_parser_calls is None, use some function
+ # of num cores instead of map_and_batch's default behavior of one batch.
+ dataset = dataset.apply(batching.map_and_batch(
+ parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
+ drop_remainder=drop_final_batch))
+
+ if prefetch_buffer_size == 0:
+ return dataset
+ else:
+ return dataset.prefetch(buffer_size=prefetch_buffer_size)
+
+
+@tf_export("data.experimental.make_csv_dataset")
+def make_csv_dataset(
+ file_pattern,
+ batch_size,
+ column_names=None,
+ column_defaults=None,
+ label_name=None,
+ select_columns=None,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ header=True,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=1,
+ sloppy=False,
+ num_rows_for_inference=100,
+ compression_type=None,
+):
+ """Reads CSV files into a dataset.
+
+ Reads CSV files into a dataset, where each element is a (features, labels)
+ tuple that corresponds to a batch of CSV rows. The features dictionary
+ maps feature column names to `Tensor`s containing the corresponding
+ feature data, and labels is a `Tensor` containing the batch's label data.
+
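+ For example, a minimal sketch (the file pattern and `label_name` below are
+ illustrative):
+
+ ```python
+ dataset = tf.data.experimental.make_csv_dataset(
+     "train-*.csv", batch_size=32, label_name="target", num_epochs=1)
+ features, labels = dataset.make_one_shot_iterator().get_next()
+ ```
+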
+ Args:
+ file_pattern: List of files or patterns of file paths containing CSV
+ records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ column_names: An optional list of strings that corresponds to the CSV
+ columns, in order. One per column of the input record. If this is not
+ provided, infers the column names from the first row of the records.
+ These names will be the keys of the features dict of each dataset element.
+ column_defaults: An optional list of default values for the CSV fields. One
+ item per selected column of the input record. Each item in the list is
+ either a valid CSV dtype (float32, float64, int32, int64, or string), or a
+ `Tensor` with one of the aforementioned types. The tensor can either be
+ a scalar default value (if the column is optional), or an empty tensor (if
+ the column is required). If a dtype is provided instead of a tensor, the
+ column is also treated as required. If this list is not provided, tries
+ to infer types based on reading the first num_rows_for_inference rows of
+ files specified, and assumes all columns are optional, defaulting to `0`
+ for numeric values and `""` for string values. If both this and
+ `select_columns` are specified, these must have the same lengths, and
+ `column_defaults` is assumed to be sorted in order of increasing column
+ index.
+ label_name: An optional string corresponding to the label column. If
+ provided, the data for this column is returned as a separate `Tensor` from
+ the features dictionary, so that the dataset complies with the format
+ expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
+ function.
+ select_columns: An optional list of integer indices or string column
+ names that specifies a subset of columns of CSV data to select. If
+ column names are provided, these must correspond to names provided in
+ `column_names` or inferred from the file header lines. When this argument
+ is specified, only a subset of CSV columns will be parsed and returned,
+ corresponding to the columns specified. Using this results in faster
+ parsing and lower memory usage. If both this and `column_defaults` are
+ specified, these must have the same lengths, and `column_defaults` is
+ assumed to be sorted in order of increasing column index.
+ field_delim: An optional `string`. Defaults to `","`. Char delimiter to
+ separate fields in a record.
+ use_quote_delim: An optional bool. Defaults to `True`. If false, treats
+ double quotation marks as regular characters inside of the string fields.
+ na_value: Additional string to recognize as NA/NaN.
+ header: A bool that indicates whether the first rows of provided CSV files
+ correspond to header lines with column names, and should not be included
+ in the data.
+ num_epochs: An int specifying the number of times this dataset is repeated.
+ If None, cycles through the dataset forever.
+ shuffle: A bool that indicates whether the input should be shuffled.
+ shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
+ ensures better shuffling, but increases memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: An int specifying the number of feature
+ batches to prefetch for performance improvement. Recommended value is the
+ number of batches consumed per training step. Defaults to auto-tune.
+ num_parallel_reads: Number of threads used to read CSV records from files.
+ If >1, the results will be interleaved.
+ sloppy: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`; note that if the seed is set, the order
+ of elements after shuffling is deterministic). Defaults to `False`.
+ num_rows_for_inference: Number of rows of a file to use for type inference
+ if `column_defaults` is not provided. If None, reads all the rows of all
+ the files. Defaults to 100.
+ compression_type: (Optional.) A `tf.string` scalar evaluating to one of
+ `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
+
+ Returns:
+ A dataset, where each element is a (features, labels) tuple that corresponds
+ to a batch of `batch_size` CSV rows. The features dictionary maps feature
+ column names to `Tensor`s containing the corresponding column data, and
+ labels is a `Tensor` containing the column data for the label column
+ specified by `label_name`.
+
+ Raises:
+ ValueError: If any of the arguments is malformed.
+ """
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Clean arguments; figure out column names and defaults
+
+ if column_names is None:
+ if not header:
+ raise ValueError("Cannot infer column names without a header line.")
+ # If column names are not provided, infer from the header lines
+ column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
+ if len(column_names) != len(set(column_names)):
+ raise ValueError("Cannot have duplicate column names.")
+
+ if select_columns is not None:
+ select_columns = _get_sorted_col_indices(select_columns, column_names)
+
+ if column_defaults is not None:
+ column_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in column_defaults
+ ]
+ else:
+ # If column defaults are not provided, infer from records at graph
+ # construction time
+ column_defaults = _infer_column_defaults(
+ filenames, len(column_names), field_delim, use_quote_delim, na_value,
+ header, num_rows_for_inference, select_columns)
+
+ if select_columns is not None and len(column_defaults) != len(select_columns):
+ raise ValueError(
+ "If specified, column_defaults and select_columns must have same "
+ "length."
+ )
+ if select_columns is not None and len(column_names) > len(select_columns):
+ # Pick the relevant subset of column names
+ column_names = [column_names[i] for i in select_columns]
+
+ if label_name is not None and label_name not in column_names:
+ raise ValueError("`label_name` provided must be one of the columns.")
+
+ def filename_to_dataset(filename):
+ return CsvDataset(
+ filename,
+ record_defaults=column_defaults,
+ field_delim=field_delim,
+ use_quote_delim=use_quote_delim,
+ na_value=na_value,
+ select_cols=select_columns,
+ header=header,
+ compression_type=compression_type,
+ )
+
+ def map_fn(*columns):
+ """Organizes columns into a features dictionary.
+
+ Args:
+ *columns: list of `Tensor`s corresponding to one csv record.
+ Returns:
+ An OrderedDict of feature names to values for that particular record. If
+ label_name is provided, extracts the label feature to be returned as the
+ second element of the tuple.
+ """
+ features = collections.OrderedDict(zip(column_names, columns))
+ if label_name is not None:
+ label = features.pop(label_name)
+ return features, label
+ return features
+
+ # Read files sequentially (if num_parallel_reads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
+
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # Apply batch before map for perf, because map has high overhead relative
+ # to the size of the computation in each map.
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(batch_size=batch_size,
+ drop_remainder=num_epochs is None)
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=False)
+ dataset = dataset.prefetch(prefetch_buffer_size)
+
+ return dataset
+
+
+_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
+
+
+@tf_export("data.experimental.CsvDataset")
+class CsvDataset(dataset_ops.DatasetSource):
+ """A Dataset comprising lines from one or more CSV files."""
+
+ def __init__(self,
+ filenames,
+ record_defaults,
+ compression_type=None,
+ buffer_size=None,
+ header=False,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ select_cols=None):
+ """Creates a `CsvDataset` by reading and decoding CSV files.
+
+ The elements of this dataset correspond to records from the file(s).
+ RFC 4180 format is expected for CSV files
+ (https://tools.ietf.org/html/rfc4180).
+ Note that we allow leading and trailing spaces in int or float fields.
+
+ For example, suppose we have a file 'my_file0.csv' with four CSV columns of
+ different data types:
+ ```
+ abcdefg,4.28E10,5.55E6,12
+ hijklmn,-5.3E14,,2
+ ```
+
+ We can construct a CsvDataset from it as follows:
+ ```python
+ dataset = tf.data.experimental.CsvDataset(
+ "my_file*.csv",
+ [tf.float32, # Required field, use dtype or empty tensor
+ tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
+ tf.int32, # Required field, use dtype or empty tensor
+ ],
+ select_cols=[1,2,3] # Only parse last three columns
+ )
+ ```
+
+ We can then iterate over the dataset as follows; the expected output is shown
+ after the `>>` markers:
+ ```python
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with tf.Session() as sess:
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+
+ >> (4.28e10, 5.55e6, 12)
+ >> (-5.3e14, 0.0, 2)
+ ```
+
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ record_defaults: A list of default values for the CSV fields, one per
+ selected column of CSV data. Each item in the list is either a valid CSV
+ `DType` (float32, float64, int32, int64, string), or a `Tensor` object
+ with one of those types. Provide a scalar `Tensor` default value for a
+ column if it is optional, or a `DType` or empty `Tensor` if it is
+ required. If both this and `select_cols` are specified, they must have
+ the same length, and `record_defaults` is assumed to be sorted in order
+ of increasing column index.
+ compression_type: (Optional.) A `tf.string` scalar evaluating to one of
+ `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
+ compression.
+ buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
+ to buffer while reading files. Defaults to 4MB.
+ header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
+ have header line(s) that should be skipped when parsing. Defaults to
+ `False`.
+ field_delim: (Optional.) A `tf.string` scalar containing the delimiter
+ character that separates fields in a record. Defaults to `","`.
+ use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
+ double quotation marks as regular characters inside of string fields
+ (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
+ na_value: (Optional.) A `tf.string` scalar indicating a value that will
+ be treated as NA/NaN.
+ select_cols: (Optional.) A sorted list of column indices to select from
+ the input data. If specified, only this subset of columns will be
+ parsed. Defaults to parsing all columns.
+ """
+ super(CsvDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+ record_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in record_defaults
+ ]
+ self._record_defaults = ops.convert_n_to_tensor(
+ record_defaults, name="record_defaults")
+ self._buffer_size = convert.optional_param_to_tensor(
+ "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
+ self._header = ops.convert_to_tensor(
+ header, dtype=dtypes.bool, name="header")
+ self._field_delim = ops.convert_to_tensor(
+ field_delim, dtype=dtypes.string, name="field_delim")
+ self._use_quote_delim = ops.convert_to_tensor(
+ use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
+ self._na_value = ops.convert_to_tensor(
+ na_value, dtype=dtypes.string, name="na_value")
+ self._select_cols = convert.optional_param_to_tensor(
+ "select_cols",
+ select_cols,
+ argument_default=[],
+ argument_dtype=dtypes.int64,
+ )
+ self._output_shapes = tuple(
+ tensor_shape.scalar() for _ in range(len(record_defaults)))
+ self._output_types = tuple(d.dtype for d in self._record_defaults)
+ self._output_classes = tuple(
+ ops.Tensor for _ in range(len(record_defaults)))
+
+ def _as_variant_tensor(self):
+ # Constructs graph node for the dataset op.
+ return gen_experimental_dataset_ops.experimental_csv_dataset(
+ filenames=self._filenames,
+ record_defaults=self._record_defaults,
+ buffer_size=self._buffer_size,
+ header=self._header,
+ output_shapes=self._output_shapes,
+ field_delim=self._field_delim,
+ use_quote_delim=self._use_quote_delim,
+ na_value=self._na_value,
+ select_cols=self._select_cols,
+ compression_type=self._compression_type,
+ )
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+@tf_export("data.experimental.make_batched_features_dataset")
+def make_batched_features_dataset(file_pattern,
+ batch_size,
+ features,
+ reader=core_readers.TFRecordDataset,
+ label_key=None,
+ reader_args=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ reader_num_threads=1,
+ parser_num_threads=2,
+ sloppy_ordering=False,
+ drop_final_batch=False):
+ """Returns a `Dataset` of feature dictionaries from `Example` protos.
+
+ If the `label_key` argument is provided, returns a `Dataset` of tuples
+ comprising a feature dictionary and a label.
+
+ Example:
+
+ ```
+ serialized_examples = [
+ features {
+ feature { key: "age" value { int64_list { value: [ 0 ] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
+ },
+ features {
+ feature { key: "age" value { int64_list { value: [] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
+ }
+ ]
+ ```
+
+ We can parse them using the following `features` argument:
+
+ ```
+ features: {
+ "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
+ "gender": FixedLenFeature([], dtype=tf.string),
+ "kws": VarLenFeature(dtype=tf.string),
+ }
+ ```
+
+ And the expected output is:
+
+ ```python
+ {
+ "age": [[0], [-1]],
+ "gender": [["f"], ["f"]],
+ "kws": SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0]],
+ values=["code", "art", "sports"]
+ dense_shape=[2, 2]),
+ }
+ ```
+
+ Args:
+ file_pattern: List of files or patterns of file paths containing
+ `Example` records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ features: A `dict` mapping feature keys to `FixedLenFeature` or
+ `VarLenFeature` values. See `tf.parse_example`.
+ reader: A function or class that can be
+ called with a `filenames` tensor and (optional) `reader_args` and returns
+ a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+ label_key: (Optional.) A string corresponding to the key under which labels
+ are stored in the `tf.Example` protos. If provided, it must be one of the
+ `features` keys; otherwise a `ValueError` is raised.
+ reader_args: Additional arguments to pass to the reader class.
+ num_epochs: Integer specifying the number of times to read through the
+ dataset. If None, cycles through the dataset forever. Defaults to `None`.
+ shuffle: A boolean, indicates whether the input should be shuffled. Defaults
+ to `True`.
+ shuffle_buffer_size: Buffer size of the ShuffleDataset. A larger capacity
+ ensures better shuffling but increases memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: Number of feature batches to prefetch in order to
+ improve performance. Recommended value is the number of batches consumed
+ per training step. Defaults to auto-tune.
+ reader_num_threads: Number of threads used to read `Example` records. If >1,
+ the results will be interleaved.
+ parser_num_threads: Number of threads to use for parsing `Example` tensors
+ into a dictionary of `Feature` tensors.
+ sloppy_ordering: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`; note that if the seed is set, the order of
+ elements after shuffling is deterministic). Defaults to `False`.
+ drop_final_batch: If `True`, and the batch size does not evenly divide the
+ input dataset size, the final smaller batch will be dropped. Defaults to
+ `False`.
+
+ Returns:
+ A dataset of `dict` elements (or a tuple of `dict` elements and label).
+ Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
+
+ Raises:
+ ValueError: If `label_key` is not one of the `features` keys.
+ """
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Read `Example` records from files as tensor objects.
+ if reader_args is None:
+ reader_args = []
+
+ # Read files sequentially (if reader_num_threads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ lambda filename: reader(filename, *reader_args),
+ cycle_length=reader_num_threads,
+ sloppy=sloppy_ordering))
+
+ # Extract values if the `Example` tensors are stored as key-value tuples.
+ if dataset.output_types == (dtypes.string, dtypes.string):
+ dataset = dataset_ops.MapDataset(
+ dataset, lambda _, v: v, use_inter_op_parallelism=False)
+
+ # Apply dataset repeat and shuffle transformations.
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(
+ batch_size, drop_remainder=drop_final_batch or num_epochs is None)
+
+ # Parse `Example` tensors to a dictionary of `Feature` tensors.
+ dataset = dataset.apply(
+ parsing_ops.parse_example_dataset(
+ features, num_parallel_calls=parser_num_threads))
+
+ if label_key:
+ if label_key not in features:
+ raise ValueError(
+ "The `label_key` provided (%r) must be one of the `features` keys." %
+ label_key)
+ dataset = dataset.map(lambda x: (x, x.pop(label_key)))
+
+ dataset = dataset.prefetch(prefetch_buffer_size)
+ return dataset
+
+
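
As a companion to the docstring above, a rough usage sketch; the TFRecord file pattern `examples-*.tfrecord` is assumed to contain the `Example` protos shown earlier:

```python
import tensorflow as tf

feature_spec = {
    "age": tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": tf.FixedLenFeature([], dtype=tf.string),
    "kws": tf.VarLenFeature(dtype=tf.string),
}

dataset = tf.data.experimental.make_batched_features_dataset(
    file_pattern="examples-*.tfrecord",  # assumed to exist
    batch_size=2,
    features=feature_spec,
    num_epochs=1)

next_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  batch = sess.run(next_batch)
  # "age" and "gender" are dense batches; "kws" is a SparseTensorValue.
  print(batch["age"], batch["gender"])
```
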
+def _get_file_names(file_pattern, shuffle):
+ """Parse list of file names from pattern, optionally shuffled.
+
+ Args:
+ file_pattern: File glob pattern, or list of glob patterns.
+ shuffle: Whether to shuffle the order of file names.
+
+ Returns:
+ List of file names matching `file_pattern`.
+
+ Raises:
+ ValueError: If `file_pattern` is empty, or pattern matches no files.
+ """
+ if isinstance(file_pattern, list):
+ if not file_pattern:
+ raise ValueError("File pattern is empty.")
+ file_names = []
+ for entry in file_pattern:
+ file_names.extend(gfile.Glob(entry))
+ else:
+ file_names = list(gfile.Glob(file_pattern))
+
+ if not file_names:
+ raise ValueError("No files match %s." % file_pattern)
+
+ # Sort files so it will be deterministic for unit tests.
+ if not shuffle:
+ file_names = sorted(file_names)
+ return file_names
+
+
+@tf_export("data.experimental.SqlDataset")
+class SqlDataset(dataset_ops.DatasetSource):
+ """A `Dataset` consisting of the results from a SQL query."""
+
+ def __init__(self, driver_name, data_source_name, query, output_types):
+ """Creates a `SqlDataset`.
+
+ `SqlDataset` allows a user to read data from the result set of a SQL query.
+ For example:
+
+ ```python
+ dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
+ "SELECT name, age FROM people",
+ (tf.string, tf.int32))
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ # Prints the rows of the result set of the above query.
+ with tf.Session() as sess:
+   while True:
+     try:
+       print(sess.run(next_element))
+     except tf.errors.OutOfRangeError:
+       break
+ ```
+
+ Args:
+ driver_name: A 0-D `tf.string` tensor containing the database type.
+ Currently, the only supported value is 'sqlite'.
+ data_source_name: A 0-D `tf.string` tensor containing a connection string
+ to connect to the database.
+ query: A 0-D `tf.string` tensor containing the SQL query to execute.
+ output_types: A tuple of `tf.DType` objects representing the types of the
+ columns returned by `query`.
+ """
+ super(SqlDataset, self).__init__()
+ self._driver_name = ops.convert_to_tensor(
+ driver_name, dtype=dtypes.string, name="driver_name")
+ self._data_source_name = ops.convert_to_tensor(
+ data_source_name, dtype=dtypes.string, name="data_source_name")
+ self._query = ops.convert_to_tensor(
+ query, dtype=dtypes.string, name="query")
+ self._output_types = output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.sql_dataset(self._driver_name,
+ self._data_source_name, self._query,
+ nest.flatten(self.output_types),
+ nest.flatten(self.output_shapes))
+
+ @property
+ def output_classes(self):
+ return nest.map_structure(lambda _: ops.Tensor, self._output_types)
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
+ self._output_types)
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/python/data/experimental/ops/resampling.py b/tensorflow/python/data/experimental/ops/resampling.py
new file mode 100644
index 0000000000..3a3040ae9a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/resampling.py
@@ -0,0 +1,296 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Resampling dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.rejection_resample")
+def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
+ """A transformation that resamples a dataset to achieve a target distribution.
+
+ **NOTE** Resampling is performed via rejection sampling; some fraction
+ of the input values will be dropped.
+
+ Args:
+ class_func: A function mapping an element of the input dataset to a scalar
+ `tf.int32` tensor. Values should be in `[0, num_classes)`.
+ target_dist: A floating point type tensor, shaped `[num_classes]`.
+ initial_dist: (Optional.) A floating point type tensor, shaped
+ `[num_classes]`. If not provided, the true class distribution is
+ estimated live in a streaming fashion.
+ seed: (Optional.) Python integer seed for the resampler.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
+ class_values_ds = dataset.map(class_func)
+
+ # Get initial distribution.
+ if initial_dist is not None:
+ initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
+ acceptance_dist, prob_of_original = (
+ _calculate_acceptance_probs_with_mixing(initial_dist_t,
+ target_dist_t))
+ initial_dist_ds = dataset_ops.Dataset.from_tensors(
+ initial_dist_t).repeat()
+ acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
+ acceptance_dist).repeat()
+ prob_of_original_ds = dataset_ops.Dataset.from_tensors(
+ prob_of_original).repeat()
+ else:
+ initial_dist_ds = _estimate_initial_dist_ds(
+ target_dist_t, class_values_ds)
+ acceptance_and_original_prob_ds = initial_dist_ds.map(
+ lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
+ initial, target_dist_t))
+ acceptance_dist_ds = acceptance_and_original_prob_ds.map(
+ lambda accept_prob, _: accept_prob)
+ prob_of_original_ds = acceptance_and_original_prob_ds.map(
+ lambda _, prob_original: prob_original)
+ filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
+ class_values_ds, seed)
+ # Prefetch filtered dataset for speed.
+ filtered_ds = filtered_ds.prefetch(3)
+
+ prob_original_static = _get_prob_original_static(
+ initial_dist_t, target_dist_t) if initial_dist is not None else None
+ if prob_original_static == 1:
+ return dataset_ops.Dataset.zip((class_values_ds, dataset))
+ elif prob_original_static == 0:
+ return filtered_ds
+ else:
+ return interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
+ weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
+ seed=seed)
+
+ return _apply_fn
+
+
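
A rough sketch of the transformation in use, on a synthetic imbalanced dataset (the 90/10 split and the 50/50 target below are illustrative):

```python
import tensorflow as tf

# Synthetic imbalanced data: ~90% class 0, ~10% class 1.
labels = tf.data.Dataset.from_tensor_slices([0] * 90 + [1] * 10)
values = tf.data.Dataset.from_tensor_slices(list(range(100)))
dataset = tf.data.Dataset.zip((labels, values)).repeat()

resampled = dataset.apply(
    tf.data.experimental.rejection_resample(
        class_func=lambda label, value: label,
        target_dist=[0.5, 0.5],
        initial_dist=[0.9, 0.1]))

# Each element comes back as (class, original_element).
next_element = resampled.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  for _ in range(5):
    print(sess.run(next_element))
```
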
+def _get_prob_original_static(initial_dist_t, target_dist_t):
+ """Returns the static probability of sampling from the original.
+
+ `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
+ an Op that it isn't defined for. We have some custom logic to avoid this.
+
+ Args:
+ initial_dist_t: A tensor of the initial distribution.
+ target_dist_t: A tensor of the target distribution.
+
+ Returns:
+ The probability of sampling from the original distribution as a constant,
+ if it is a constant, or `None`.
+ """
+ init_static = tensor_util.constant_value(initial_dist_t)
+ target_static = tensor_util.constant_value(target_dist_t)
+
+ if init_static is None or target_static is None:
+ return None
+ else:
+ return np.min(target_static / init_static)
+
+
+def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
+ seed):
+ """Filters a dataset based on per-class acceptance probabilities.
+
+ Args:
+ dataset: The dataset to be filtered.
+ acceptance_dist_ds: A dataset of acceptance probabilities.
+ initial_dist_ds: A dataset of the initial probability distribution, given or
+ estimated.
+ class_values_ds: A dataset of the corresponding classes.
+ seed: (Optional.) Python integer seed for the resampler.
+
+ Returns:
+ A dataset of (class value, data) after filtering.
+ """
+ def maybe_warn_on_large_rejection(accept_dist, initial_dist):
+ proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
+ return control_flow_ops.cond(
+ math_ops.less(proportion_rejected, .5),
+ lambda: accept_dist,
+ lambda: logging_ops.Print( # pylint: disable=g-long-lambda
+ accept_dist, [proportion_rejected, initial_dist, accept_dist],
+ message="Proportion of examples rejected by sampler is high: ",
+ summarize=100,
+ first_n=10))
+
+ acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
+ initial_dist_ds))
+ .map(maybe_warn_on_large_rejection))
+
+ def _gather_and_copy(class_val, acceptance_prob, data):
+ return class_val, array_ops.gather(acceptance_prob, class_val), data
+
+ current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
+ (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
+ filtered_ds = (
+ current_probabilities_and_class_and_data_ds
+ .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
+ return filtered_ds.map(lambda class_value, _, data: (class_value, data))
+
+
+def _estimate_initial_dist_ds(
+ target_dist_t, class_values_ds, dist_estimation_batch_size=32,
+ smoothing_constant=10):
+ num_classes = (target_dist_t.shape[0].value or
+ array_ops.shape(target_dist_t)[0])
+ initial_examples_per_class_seen = array_ops.fill(
+ [num_classes], np.int64(smoothing_constant))
+
+ def update_estimate_and_tile(num_examples_per_class_seen, c):
+ updated_examples_per_class_seen, dist = _estimate_data_distribution(
+ c, num_examples_per_class_seen)
+ tiled_dist = array_ops.tile(
+ array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
+ return updated_examples_per_class_seen, tiled_dist
+
+ initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
+ .apply(scan_ops.scan(initial_examples_per_class_seen,
+ update_estimate_and_tile))
+ .apply(batching.unbatch()))
+
+ return initial_dist_ds
+
+
+def _get_target_to_initial_ratio(initial_probs, target_probs):
+ # Add tiny to initial_probs to avoid divide by zero.
+ denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
+ return target_probs / denom
+
+
+def _estimate_data_distribution(c, num_examples_per_class_seen):
+ """Estimate data distribution as labels are seen.
+
+ Args:
+ c: The class labels. Type `int32`, shape `[batch_size]`.
+ num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
+ containing counts.
+
+ Returns:
+ num_examples_per_class_seen: Updated counts. Type `int64`, shape
+ `[num_classes]`.
+ dist: The updated distribution. Type `float32`, shape `[num_classes]`.
+ """
+ num_classes = num_examples_per_class_seen.get_shape()[0].value
+ # Update the class-count based on what labels are seen in batch.
+ num_examples_per_class_seen = math_ops.add(
+ num_examples_per_class_seen, math_ops.reduce_sum(
+ array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
+ init_prob_estimate = math_ops.truediv(
+ num_examples_per_class_seen,
+ math_ops.reduce_sum(num_examples_per_class_seen))
+ dist = math_ops.cast(init_prob_estimate, dtypes.float32)
+ return num_examples_per_class_seen, dist
+
+
+def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
+ """Calculates the acceptance probabilities and mixing ratio.
+
+ In this case, we assume that we can *either* sample from the original data
+ distribution with probability `m`, or sample from a reshaped distribution
+ that comes from rejection sampling on the original distribution. This
+ rejection sampling is done on a per-class basis, with `a_i` representing the
+ probability of accepting data from class `i`.
+
+ This method is based on solving the following analysis for the reshaped
+ distribution:
+
+ Let F be the probability of a rejection (on any example).
+ Let p_i be the proportion of examples in the data in class i (init_probs).
+ Let a_i be the rate at which the rejection sampler should *accept* class i.
+ Let t_i be the target proportion in the minibatches for class i (target_probs).
+
+ ```
+ F = sum_i(p_i * (1-a_i))
+ = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
+ ```
+
+ An example with class `i` will be accepted if `k` rejections occur, then an
+ example with class `i` is seen by the rejector, and it is accepted. This can
+ be written as follows:
+
+ ```
+ t_i = sum_k=0^inf(F^k * p_i * a_i)
+ = p_i * a_i / (1 - F) using geometric series identity, since 0 <= F < 1
+ = p_i * a_i / sum_j(p_j * a_j) using F from above
+ ```
+
+ Note that the following constraints hold:
+ ```
+ 0 <= p_i <= 1, sum_i(p_i) = 1
+ 0 <= a_i <= 1
+ 0 <= t_i <= 1, sum_i(t_i) = 1
+ ```
+
+ A solution for a_i in terms of the other variables is the following:
+ ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
+
+ If we try to minimize the amount of data rejected, we get the following:
+
+ ```
+ M_max = max_i [ t_i / p_i ]
+ M_min = min_i [ t_i / p_i ]
+ ```
+
+ The desired probability of accepting data if it comes from class `i`:
+
+ ```a_i = (t_i/p_i - m) / (M_max - m)```
+
+ The desired probability of pulling a data element from the original dataset,
+ rather than the filtered one:
+
+ ```m = M_min```
+
+ Args:
+ initial_probs: A Tensor of the initial probability distribution, given or
+ estimated.
+ target_probs: A Tensor of the target probability distribution over classes.
+
+ Returns:
+ A tuple `(acceptance_probs, mixing_prob)`: a 1D Tensor with the per-class
+ acceptance probabilities, and the desired probability of pulling a data
+ element from the original distribution rather than the filtered one.
+ """
+ ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
+ max_ratio = math_ops.reduce_max(ratio_l)
+ min_ratio = math_ops.reduce_min(ratio_l)
+
+ # Target prob to sample from original distribution.
+ m = min_ratio
+
+ # TODO(joelshor): Simplify fraction, if possible.
+ a_i = (ratio_l - m) / (max_ratio - m)
+ return a_i, m
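
As a sanity check of the derivation in the docstring, a small NumPy re-computation (not part of the change) with an illustrative 3-class initial and target distribution:

```python
import numpy as np

p = np.array([0.7, 0.2, 0.1])      # initial (data) distribution, p_i
t = np.array([1.0, 1.0, 1.0]) / 3  # target distribution, t_i

ratio = t / p                        # t_i / p_i
m = ratio.min()                      # probability of sampling from the original dataset
a = (ratio - m) / (ratio.max() - m)  # per-class acceptance probabilities, a_i

# Mixing the original distribution (weight m) with the rejection-sampled one
# (weight 1 - m) should recover the target distribution exactly.
resampled = p * a / np.sum(p * a)
print(np.allclose(m * p + (1 - m) * resampled, t))  # True
```
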
diff --git a/tensorflow/python/data/experimental/ops/scan_ops.py b/tensorflow/python/data/experimental/ops/scan_ops.py
new file mode 100644
index 0000000000..e05e7c5a18
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/scan_ops.py
@@ -0,0 +1,177 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Scan dataset transformation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ScanDataset(dataset_ops.UnaryDataset):
+ """A dataset that scans a function across its input."""
+
+ def __init__(self, input_dataset, initial_state, scan_func):
+ """See `scan()` for details."""
+ super(_ScanDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ self._initial_state = nest.pack_sequence_as(initial_state, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(initial_state))
+ ])
+
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state. The shapes may be refined by running `tf_scan_func` one
+ # or more times below.
+ self._state_classes = sparse.get_classes(self._initial_state)
+ self._state_shapes = nest.pack_sequence_as(
+ self._initial_state,
+ [t.get_shape() for t in nest.flatten(self._initial_state)])
+ self._state_types = nest.pack_sequence_as(
+ self._initial_state,
+ [t.dtype for t in nest.flatten(self._initial_state)])
+
+ # Will be populated by calling `tf_scan_func`.
+ self._output_classes = None
+ self._output_shapes = None
+ self._output_types = None
+
+ # Iteratively rerun the scan function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ scan_func,
+ "tf.data.experimental.scan()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+ if not (
+ isinstance(wrapped_func.output_types, collections.Sequence) and
+ len(wrapped_func.output_types) == 2):
+ raise TypeError("The scan function must return a pair comprising the "
+ "new state and the output value.")
+
+ new_state_classes, self._output_classes = wrapped_func.output_classes
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(new_state_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, new_state_classes))
+
+ # Extract and validate type information from the returned values.
+ new_state_types, self._output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(new_state_types), nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, new_state_types))
+
+ # Extract shape information from the returned values.
+ new_state_shapes, self._output_shapes = wrapped_func.output_shapes
+
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(new_state_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._scan_func = wrapped_func.function
+ self._scan_func.add_to_graph(ops.get_default_graph())
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.scan_dataset(
+ input_t,
+ nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
+ self._scan_func.captured_inputs,
+ f=self._scan_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+@tf_export("data.experimental.scan")
+def scan(initial_state, scan_func):
+ """A transformation that scans a function across an input dataset.
+
+ This transformation is a stateful relative of `tf.data.Dataset.map`.
+ In addition to mapping `scan_func` across the elements of the input dataset,
+ `scan()` accumulates one or more state tensors, whose initial values are
+ `initial_state`.
+
+ Args:
+ initial_state: A nested structure of tensors, representing the initial state
+ of the accumulator.
+ scan_func: A function that maps `(old_state, input_element)` to
+ `(new_state, output_element)`. It must take two arguments and return a
+ pair of nested structures of tensors. The `new_state` must match the
+ structure of `initial_state`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return _ScanDataset(dataset, initial_state, scan_func)
+
+ return _apply_fn
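
For illustration, a minimal running-sum sketch using the new endpoint:

```python
import tensorflow as tf

# Accumulates a running sum over the integers 1..5, emitting 1, 3, 6, 10, 15.
dataset = tf.data.Dataset.range(1, 6).apply(
    tf.data.experimental.scan(
        initial_state=tf.constant(0, dtype=tf.int64),
        scan_func=lambda state, x: (state + x, state + x)))

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  for _ in range(5):
    print(sess.run(next_element))
```
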
diff --git a/tensorflow/python/data/experimental/ops/shuffle_ops.py b/tensorflow/python/data/experimental/ops/shuffle_ops.py
new file mode 100644
index 0000000000..a4307212da
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/shuffle_ops.py
@@ -0,0 +1,102 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental shuffle ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import random_seed
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that fuses `shuffle` and `repeat`."""
+
+ def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+ super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._buffer_size = ops.convert_to_tensor(
+ buffer_size, dtype=dtypes.int64, name="buffer_size")
+ if count is None:
+ self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
+ else:
+ self._count = ops.convert_to_tensor(
+ count, dtype=dtypes.int64, name="count")
+ self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ input_resource = self._input_dataset._as_variant_tensor()
+ return gen_dataset_ops.shuffle_and_repeat_dataset(
+ input_resource,
+ buffer_size=self._buffer_size,
+ count=self._count,
+ seed=self._seed,
+ seed2=self._seed2,
+ **dataset_ops.flat_structure(self))
+ # pylint: enable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+@tf_export("data.experimental.shuffle_and_repeat")
+def shuffle_and_repeat(buffer_size, count=None, seed=None):
+ """Shuffles and repeats a Dataset returning a new permutation for each epoch.
+
+ `dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size, count))`
+
+ is equivalent to
+
+ `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`
+
+ The difference is that the latter dataset is not serializable, so if you
+ need to checkpoint an input pipeline with reshuffling, you must use this
+ implementation.
+
+ Args:
+ buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
+ maximum number of elements that will be buffered when prefetching.
+ count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ number of times the dataset should be repeated. The default behavior
+ (if `count` is `None` or `-1`) is for the dataset to be repeated
+ indefinitely.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ `tf.set_random_seed` for behavior.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset): # pylint: disable=missing-docstring
+ return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
+
+ return _apply_fn
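
A minimal sketch of the fused transformation; the buffer size of 5 covers the whole toy dataset, so each epoch yields a full permutation:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5).apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=5, count=2, seed=42))

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  # Two epochs of [0, 5), each in a (potentially) different order.
  print([sess.run(next_element) for _ in range(10)])
```
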
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py
index bc47c5989d..c918d223e8 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/python/data/experimental/ops/stats_ops.py
@@ -21,8 +21,10 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("data.experimental.StatsAggregator")
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
@@ -34,7 +36,7 @@ class StatsAggregator(object):
```python
dataset = ...
- dataset = dataset.apply(stats_ops.latency_stats("total_bytes"))
+ dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
```
To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
@@ -46,7 +48,7 @@ class StatsAggregator(object):
# Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
dataset = dataset.apply(
- tf.contrib.data.set_stats_aggregator(stats_aggregator))
+ tf.data.experimental.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_one_shot_iterator()
```
@@ -111,11 +113,12 @@ class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
return self._input_dataset.output_classes
+@tf_export("data.experimental.set_stats_aggregator")
def set_stats_aggregator(stats_aggregator):
"""Set the given `stats_aggregator` for aggregating the input dataset stats.
Args:
- stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
+ stats_aggregator: A `tf.data.experimental.StatsAggregator` object.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -128,8 +131,8 @@ def set_stats_aggregator(stats_aggregator):
return _apply_fn
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
+# TODO(b/38416882): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
def bytes_produced_stats(tag):
"""Records the number of bytes produced by each element of the input dataset.
@@ -152,6 +155,7 @@ def bytes_produced_stats(tag):
return _apply_fn
+@tf_export("data.experimental.latency_stats")
def latency_stats(tag):
"""Records the latency of producing each element of the input dataset.
diff --git a/tensorflow/python/data/experimental/ops/threadpool.py b/tensorflow/python/data/experimental/ops/threadpool.py
new file mode 100644
index 0000000000..3ea017c6e8
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/threadpool.py
@@ -0,0 +1,104 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for controlling threading in `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import resource_variable_ops
+
+_uid_counter = 0
+_uid_lock = threading.Lock()
+
+
+def _generate_shared_name(prefix):
+ with _uid_lock:
+ global _uid_counter
+ uid = _uid_counter
+ _uid_counter += 1
+ return "{}{}".format(prefix, uid)
+
+
+# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+class PrivateThreadPool(object):
+ """A stateful resource that represents a private thread pool."""
+
+ def __init__(self, num_threads, display_name=None,
+ max_intra_op_parallelism=1):
+ """Creates a `PrivateThreadPool` with the given number of threads."""
+ if context.executing_eagerly():
+ shared_name = _generate_shared_name("privatethreadpool")
+ self._resource = ged_ops.experimental_thread_pool_handle(
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name,
+ shared_name=shared_name)
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._resource, handle_device=context.context().device_name)
+ else:
+ self._resource = ged_ops.experimental_thread_pool_handle(
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name)
+
+
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and sets a custom threadpool."""
+
+ def __init__(self, input_dataset, thread_pool):
+ super(_ThreadPoolDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._thread_pool = thread_pool
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_thread_pool_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._thread_pool._resource, # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+def override_threadpool(dataset, thread_pool):
+ """Returns a new dataset that uses the given thread pool for its operations.
+
+ Args:
+ dataset: A `tf.data.Dataset` object.
+ thread_pool: A `PrivateThreadPool` object.
+
+ Returns:
+ A dataset containing the same values as `dataset`, but which uses
+ `thread_pool` to compute any of its parallel operations (such as
+ `tf.data.Dataset.map`).
+ """
+ return _ThreadPoolDataset(dataset, thread_pool)
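
Since `PrivateThreadPool` and `override_threadpool` are not exported yet, a usage sketch has to import the internal module added by this change; the pool size and display name below are arbitrary:

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import threadpool

# Run the parallel map on a private pool of 4 threads instead of the shared one.
dataset = tf.data.Dataset.range(1000).map(
    lambda x: x * x, num_parallel_calls=4)
dataset = threadpool.override_threadpool(
    dataset, threadpool.PrivateThreadPool(4, display_name="priv_pool"))

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(next_element))  # 0
```
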
diff --git a/tensorflow/python/data/experimental/ops/unique.py b/tensorflow/python/data/experimental/ops/unique.py
new file mode 100644
index 0000000000..2a7775c456
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/unique.py
@@ -0,0 +1,79 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unique element dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.unique")
+def unique():
+ """Creates a `Dataset` from another `Dataset`, discarding duplicates.
+
+ Use this transformation to produce a dataset that contains one instance of
+ each unique element in the input. For example:
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
+
+ # Using `unique()` will drop the duplicate elements.
+ dataset = dataset.apply(tf.data.experimental.unique()) # ==> { 1, 37, 2 }
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _UniqueDataset(dataset)
+
+ return _apply_fn
+
+
+class _UniqueDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` contains the unique elements from its input."""
+
+ def __init__(self, input_dataset):
+ """See `unique()` for details."""
+ super(_UniqueDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
+ dtypes.string):
+ raise TypeError(
+ "`tf.data.experimental.unique()` only supports inputs with a single "
+ "`tf.int32`, `tf.int64`, or `tf.string` component.")
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_unique_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py
new file mode 100644
index 0000000000..994447cb4d
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/writers.py
@@ -0,0 +1,60 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for tf.data writers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.TFRecordWriter")
+class TFRecordWriter(object):
+ """Writes data to a TFRecord file."""
+
+ def __init__(self, filename, compression_type=None):
+ self._filename = ops.convert_to_tensor(
+ filename, dtypes.string, name="filename")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+
+ def write(self, dataset):
+ """Returns a `tf.Operation` to write a dataset to a file.
+
+ Args:
+ dataset: A `tf.data.Dataset` whose elements are to be written to a file.
+
+ Returns:
+ A `tf.Operation` that, when run, writes contents of `dataset` to a file.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+ if (dataset.output_types != dtypes.string or
+ dataset.output_shapes != tensor_shape.scalar()):
+ raise TypeError(
+ "`dataset` must produce scalar `DT_STRING` tensors whereas it "
+ "produces shape {0} and types {1}".format(dataset.output_shapes,
+ dataset.output_types))
+ return gen_dataset_ops.dataset_to_tf_record(
+ dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
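
A round-trip sketch of the writer; the output path `/tmp/strings.tfrecord` is arbitrary:

```python
import tensorflow as tf

# Write a small dataset of scalar strings to a TFRecord file...
dataset = tf.data.Dataset.from_tensor_slices([b"abc", b"def", b"ghi"])
writer = tf.data.experimental.TFRecordWriter("/tmp/strings.tfrecord")
write_op = writer.write(dataset)

# ...and read it back with the core TFRecordDataset.
round_trip = tf.data.TFRecordDataset("/tmp/strings.tfrecord")
next_element = round_trip.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  sess.run(write_op)
  print(sess.run(next_element))  # b'abc'
```
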
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6bba72a8e9..3b9d3a639d 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -889,8 +889,8 @@ class Dataset(object):
will be padded out to the maximum length of all elements in that
dimension.
- See also `tf.contrib.data.dense_to_sparse_batch`, which combines elements
- that may have different shapes into a `tf.SparseTensor`.
+ See also `tf.data.experimental.dense_to_sparse_batch`, which combines
+ elements that may have different shapes into a `tf.SparseTensor`.
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index 3bbebd7878..aca989e03a 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -31,7 +31,7 @@ class Optional(object):
An `Optional` can represent the result of an operation that may fail as a
value, rather than raising an exception and halting execution. For example,
- `tf.contrib.data.get_next_as_optional` returns an `Optional` that either
+ `tf.data.experimental.get_next_as_optional` returns an `Optional` that either
contains the next value from a `tf.data.Iterator` if one exists, or a "none"
value that indicates the end of the sequence has been reached.
"""
@@ -111,7 +111,7 @@ class Optional(object):
class _OptionalImpl(Optional):
- """Concrete implementation of `tf.contrib.data.Optional`.
+ """Concrete implementation of `tf.data.experimental.Optional`.
NOTE(mrry): This implementation is kept private, to avoid defining
`Optional.__init__()` in the public API.
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index b0f26631f9..d08da6704c 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -129,7 +129,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy, buffer_output_elements, prefetch_input_elements):
- """See `tf.contrib.data.parallel_interleave()` for details."""
+ """See `tf.data.experimental.parallel_interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func,
cycle_length, block_length)
self._sloppy = ops.convert_to_tensor(
@@ -158,7 +158,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
# pylint: enable=protected-access
def _transformation_name(self):
- return "tf.contrib.data.parallel_interleave()"
+ return "tf.data.experimental.parallel_interleave()"
@tf_export("data.TFRecordDataset")
diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py
index 019f13c450..f9bb3148fb 100644
--- a/tensorflow/python/debug/examples/debug_tflearn_iris.py
+++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py
@@ -94,13 +94,15 @@ def main(_):
"sepal_length", "sepal_width", "petal_length", "petal_width", "label"]
batch_size = 32
def training_input_fn():
- return tf.contrib.data.make_csv_dataset(
- [training_data_path], batch_size,
- column_names=column_names, label_name="label")
+ return tf.data.experimental.make_csv_dataset([training_data_path],
+ batch_size,
+ column_names=column_names,
+ label_name="label")
def test_input_fn():
- return tf.contrib.data.make_csv_dataset(
- [test_data_path], batch_size,
- column_names=column_names, label_name="label")
+ return tf.data.experimental.make_csv_dataset([test_data_path],
+ batch_size,
+ column_names=column_names,
+ label_name="label")
feature_columns = [tf.feature_column.numeric_column(feature)
for feature in column_names[:-1]]
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 5ce5410e0b..533a138a39 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -8,6 +8,7 @@ TENSORFLOW_API_INIT_FILES = [
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
+ "data/experimental/__init__.py",
"debugging/__init__.py",
"distributions/__init__.py",
"dtypes/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index 587eb232f5..0747424eab 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -8,6 +8,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
+ "data/experimental/__init__.py",
"debugging/__init__.py",
"distributions/__init__.py",
"dtypes/__init__.py",
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt
new file mode 100644
index 0000000000..03c16cda8b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt
@@ -0,0 +1,30 @@
+path: "tensorflow.data.experimental.CheckpointInputPipelineHook"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.iterator_ops.CheckpointInputPipelineHook\'>"
+ is_instance: "<class \'tensorflow.python.training.session_run_hook.SessionRunHook\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'estimator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_create_session"
+ argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_run"
+ argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "before_run"
+ argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "begin"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "end"
+ argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..3eeaa1b185
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.CsvDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
new file mode 100644
index 0000000000..0c0405ee02
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -0,0 +1,127 @@
+path: "tensorflow.data.experimental.CsvDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.CsvDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optional.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optional.pbtxt
new file mode 100644
index 0000000000..b4c9459098
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optional.pbtxt
@@ -0,0 +1,28 @@
+path: "tensorflow.data.experimental.Optional"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.ops.optional_ops.Optional\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "value_structure"
+ mtype: "<class \'abc.abstractproperty\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "from_value"
+ argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "has_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "none_from_structure"
+ argspec: "args=[\'value_structure\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..2991b12f64
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.RandomDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
new file mode 100644
index 0000000000..bce0be4b17
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -0,0 +1,127 @@
+path: "tensorflow.data.experimental.RandomDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.random_ops.RandomDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-reducer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-reducer.pbtxt
new file mode 100644
index 0000000000..6b477a8a72
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-reducer.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.data.experimental.Reducer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.grouping.Reducer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "finalize_func"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "init_func"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "reduce_func"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'init_func\', \'reduce_func\', \'finalize_func\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..948e99ef86
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.SqlDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
new file mode 100644
index 0000000000..8aeae92d96
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -0,0 +1,127 @@
+path: "tensorflow.data.experimental.SqlDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.SqlDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-aggregator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-aggregator.pbtxt
new file mode 100644
index 0000000000..0bcc8cf3e8
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-aggregator.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.data.experimental.StatsAggregator"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_ops.StatsAggregator\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_summary"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-t-f-record-writer.pbtxt
new file mode 100644
index 0000000000..6f9d18a701
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-t-f-record-writer.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.data.experimental.TFRecordWriter"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.writers.TFRecordWriter\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filename\', \'compression_type\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "write"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
new file mode 100644
index 0000000000..b14585f8d7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
@@ -0,0 +1,139 @@
+path: "tensorflow.data.experimental"
+tf_module {
+ member {
+ name: "CheckpointInputPipelineHook"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "CsvDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "Optional"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "RandomDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "Reducer"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SqlDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "StatsAggregator"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordWriter"
+ mtype: "<type \'type\'>"
+ }
+ member_method {
+ name: "Counter"
+ argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
+ }
+ member_method {
+ name: "bucket_by_sequence_length"
+ argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], "
+ }
+ member_method {
+ name: "choose_from_datasets"
+ argspec: "args=[\'datasets\', \'choice_dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "copy_to_device"
+ argspec: "args=[\'target_device\', \'source_device\'], varargs=None, keywords=None, defaults=[\'/cpu:0\'], "
+ }
+ member_method {
+ name: "dense_to_sparse_batch"
+ argspec: "args=[\'batch_size\', \'row_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "enumerate_dataset"
+ argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
+ }
+ member_method {
+ name: "get_next_as_optional"
+ argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_single_element"
+ argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "group_by_reducer"
+ argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "group_by_window"
+ argspec: "args=[\'key_func\', \'reduce_func\', \'window_size\', \'window_size_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "ignore_errors"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "latency_stats"
+ argspec: "args=[\'tag\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "make_batched_features_dataset"
+ argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\"<class \'tensorflow.python.data.ops.readers.TFRecordDataset\'>\", \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'-1\', \'1\', \'2\', \'False\', \'False\'], "
+ }
+ member_method {
+ name: "make_csv_dataset"
+ argspec: "args=[\'file_pattern\', \'batch_size\', \'column_names\', \'column_defaults\', \'label_name\', \'select_columns\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'header\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'num_parallel_reads\', \'sloppy\', \'num_rows_for_inference\', \'compression_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \',\', \'True\', \'\', \'True\', \'None\', \'True\', \'10000\', \'None\', \'-1\', \'1\', \'False\', \'100\', \'None\'], "
+ }
+ member_method {
+ name: "make_saveable_from_iterator"
+ argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map_and_batch"
+ argspec: "args=[\'map_func\', \'batch_size\', \'num_parallel_batches\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "parallel_interleave"
+ argspec: "args=[\'map_func\', \'cycle_length\', \'block_length\', \'sloppy\', \'buffer_output_elements\', \'prefetch_input_elements\'], varargs=None, keywords=None, defaults=[\'1\', \'False\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "parse_example_dataset"
+ argspec: "args=[\'features\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ }
+ member_method {
+ name: "prefetch_to_device"
+ argspec: "args=[\'device\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "rejection_resample"
+ argspec: "args=[\'class_func\', \'target_dist\', \'initial_dist\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "sample_from_datasets"
+ argspec: "args=[\'datasets\', \'weights\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "scan"
+ argspec: "args=[\'initial_state\', \'scan_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_stats_aggregator"
+ argspec: "args=[\'stats_aggregator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle_and_repeat"
+ argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "unbatch"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "unique"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt
index 56fb270a49..e205157523 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt
@@ -20,4 +20,8 @@ tf_module {
name: "TextLineDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
+ member {
+ name: "experimental"
+ mtype: "<type \'module\'>"
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt
new file mode 100644
index 0000000000..03c16cda8b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt
@@ -0,0 +1,30 @@
+path: "tensorflow.data.experimental.CheckpointInputPipelineHook"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.iterator_ops.CheckpointInputPipelineHook\'>"
+ is_instance: "<class \'tensorflow.python.training.session_run_hook.SessionRunHook\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'estimator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_create_session"
+ argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_run"
+ argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "before_run"
+ argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "begin"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "end"
+ argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..3eeaa1b185
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.CsvDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
new file mode 100644
index 0000000000..0c0405ee02
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -0,0 +1,127 @@
+path: "tensorflow.data.experimental.CsvDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.CsvDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optional.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optional.pbtxt
new file mode 100644
index 0000000000..b4c9459098
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optional.pbtxt
@@ -0,0 +1,28 @@
+path: "tensorflow.data.experimental.Optional"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.ops.optional_ops.Optional\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "value_structure"
+ mtype: "<class \'abc.abstractproperty\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "from_value"
+ argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "has_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "none_from_structure"
+ argspec: "args=[\'value_structure\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..2991b12f64
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.RandomDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
new file mode 100644
index 0000000000..bce0be4b17
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -0,0 +1,127 @@
+path: "tensorflow.data.experimental.RandomDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.random_ops.RandomDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-reducer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-reducer.pbtxt
new file mode 100644
index 0000000000..6b477a8a72
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-reducer.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.data.experimental.Reducer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.grouping.Reducer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "finalize_func"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "init_func"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "reduce_func"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'init_func\', \'reduce_func\', \'finalize_func\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..948e99ef86
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.SqlDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
new file mode 100644
index 0000000000..8aeae92d96
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -0,0 +1,127 @@
+path: "tensorflow.data.experimental.SqlDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.SqlDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-aggregator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-aggregator.pbtxt
new file mode 100644
index 0000000000..0bcc8cf3e8
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-aggregator.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.data.experimental.StatsAggregator"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_ops.StatsAggregator\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_summary"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-t-f-record-writer.pbtxt
new file mode 100644
index 0000000000..6f9d18a701
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-t-f-record-writer.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.data.experimental.TFRecordWriter"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.writers.TFRecordWriter\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filename\', \'compression_type\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "write"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
new file mode 100644
index 0000000000..b14585f8d7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
@@ -0,0 +1,139 @@
+path: "tensorflow.data.experimental"
+tf_module {
+ member {
+ name: "CheckpointInputPipelineHook"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "CsvDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "Optional"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "RandomDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "Reducer"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SqlDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "StatsAggregator"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordWriter"
+ mtype: "<type \'type\'>"
+ }
+ member_method {
+ name: "Counter"
+ argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
+ }
+ member_method {
+ name: "bucket_by_sequence_length"
+ argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], "
+ }
+ member_method {
+ name: "choose_from_datasets"
+ argspec: "args=[\'datasets\', \'choice_dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "copy_to_device"
+ argspec: "args=[\'target_device\', \'source_device\'], varargs=None, keywords=None, defaults=[\'/cpu:0\'], "
+ }
+ member_method {
+ name: "dense_to_sparse_batch"
+ argspec: "args=[\'batch_size\', \'row_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "enumerate_dataset"
+ argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
+ }
+ member_method {
+ name: "get_next_as_optional"
+ argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_single_element"
+ argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "group_by_reducer"
+ argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "group_by_window"
+ argspec: "args=[\'key_func\', \'reduce_func\', \'window_size\', \'window_size_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "ignore_errors"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "latency_stats"
+ argspec: "args=[\'tag\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "make_batched_features_dataset"
+ argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\"<class \'tensorflow.python.data.ops.readers.TFRecordDataset\'>\", \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'-1\', \'1\', \'2\', \'False\', \'False\'], "
+ }
+ member_method {
+ name: "make_csv_dataset"
+ argspec: "args=[\'file_pattern\', \'batch_size\', \'column_names\', \'column_defaults\', \'label_name\', \'select_columns\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'header\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'num_parallel_reads\', \'sloppy\', \'num_rows_for_inference\', \'compression_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \',\', \'True\', \'\', \'True\', \'None\', \'True\', \'10000\', \'None\', \'-1\', \'1\', \'False\', \'100\', \'None\'], "
+ }
+ member_method {
+ name: "make_saveable_from_iterator"
+ argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map_and_batch"
+ argspec: "args=[\'map_func\', \'batch_size\', \'num_parallel_batches\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "parallel_interleave"
+ argspec: "args=[\'map_func\', \'cycle_length\', \'block_length\', \'sloppy\', \'buffer_output_elements\', \'prefetch_input_elements\'], varargs=None, keywords=None, defaults=[\'1\', \'False\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "parse_example_dataset"
+ argspec: "args=[\'features\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ }
+ member_method {
+ name: "prefetch_to_device"
+ argspec: "args=[\'device\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "rejection_resample"
+ argspec: "args=[\'class_func\', \'target_dist\', \'initial_dist\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "sample_from_datasets"
+ argspec: "args=[\'datasets\', \'weights\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "scan"
+ argspec: "args=[\'initial_state\', \'scan_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_stats_aggregator"
+ argspec: "args=[\'stats_aggregator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle_and_repeat"
+ argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "unbatch"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "unique"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
index 56fb270a49..e205157523 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
@@ -20,4 +20,8 @@ tf_module {
name: "TextLineDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
+ member {
+ name: "experimental"
+ mtype: "<type \'module\'>"
+ }
}
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 3a1c4a45d4..164b3d8303 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -64,8 +64,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
"//tensorflow/contrib/gan:gan",
@@ -106,6 +104,8 @@ COMMON_PIP_DEPS = [
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python:util_example_parser_configuration",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
+ "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/eager:eager_pip",