author    Derek Murray <mrry@google.com>  2018-10-01 16:45:11 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-01 16:50:05 -0700
commit    b72265dc002e712fc3d0f33434f13c7a36a484b2 (patch)
tree      f92d1f23c329654772f95d93f5cf4458741b72df /tensorflow/contrib/data
parent    bb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (diff)
[tf.data] Deprecate `tf.contrib.data` and introduce `tf.data.experimental` to replace it.
This change prepares `tf.data` for TensorFlow 2.0, where `tf.contrib` will no longer exist. It retains the pre-existing endpoints in `tf.contrib.data` with deprecation warnings.

Note there are some exceptions to the move:

* Deprecated symbols in `tf.contrib.data` have not been moved to `tf.data.experimental`, because replacements already exist.
* `tf.contrib.data.LMDBDataset` has not been moved, because we plan to move it to a SIG-maintained repository.
* `tf.contrib.data.assert_element_shape()` has not yet been moved, because it depends on functionality in `tf.contrib`, and it will move in a later change.
* `tf.contrib.data.AUTOTUNE` has not yet been moved, because we have not yet determined how to `tf_export()` a Python integer.
* The stats-related API endpoints have not yet appeared in a released version of TensorFlow, so these are moved to `tf.data.experimental` without retaining an endpoint in `tf.contrib.data`.

In addition, this change includes some build rule and ApiDef refactoring:

* Some of the "//third_party/tensorflow/python:training" dependencies had to be split in order to avoid a circular dependency.
* The `tf.contrib.stateless` ops now have a private core library for the generated wrappers (and accordingly are hidden in their ApiDef) so that `tf.data.experimental.sample_from_datasets()` can depend on them.

PiperOrigin-RevId: 215304249
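For code that only uses the renamed endpoints, the port is mechanical: each experimental transformation keeps its name and moves from `tf.contrib.data` to `tf.data.experimental`. A minimal before/after sketch, assuming a TensorFlow 1.x build in which both endpoints are present; `shuffle_and_repeat` stands in for any of the moved symbols:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# Before this change: the contrib endpoint (kept, but it now emits a
# deprecation warning pointing at tf.data.experimental).
old_style = dataset.apply(
    tf.contrib.data.shuffle_and_repeat(buffer_size=10, count=2))

# After this change: the same transformation under its new endpoint.
new_style = dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=10, count=2))
```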
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- tensorflow/contrib/data/README.md | 18
-rw-r--r-- tensorflow/contrib/data/__init__.py | 11
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/BUILD | 560
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py | 226
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 987
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 824
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py | 632
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py | 71
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py | 148
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py | 76
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py | 79
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py | 811
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py | 125
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py | 359
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py | 281
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/BUILD | 164
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py | 65
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py | 103
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py | 58
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py | 225
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py | 85
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py | 223
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py | 183
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py | 58
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py | 109
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py | 851
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py | 948
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py | 78
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py | 1083
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py | 353
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py) | 40
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/resample_test.py | 182
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py | 172
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/BUILD | 555
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py | 83
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py | 253
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py | 49
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py | 73
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py | 95
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py | 692
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py | 71
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py | 45
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py | 122
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py | 61
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py | 57
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py | 46
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py | 83
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py | 88
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py | 140
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py | 39
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py | 66
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py | 101
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py | 139
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py | 50
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py | 39
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py | 118
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py | 46
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py | 40
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py | 129
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py | 85
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py | 39
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py | 148
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py | 53
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py | 106
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py | 53
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py | 99
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py | 51
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py | 40
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py | 54
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py | 115
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py | 590
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py | 95
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py | 253
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py | 71
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py | 91
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py | 83
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py | 527
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py | 118
-rw-r--r-- tensorflow/contrib/data/python/ops/BUILD | 170
-rw-r--r-- tensorflow/contrib/data/python/ops/batching.py | 549
-rw-r--r-- tensorflow/contrib/data/python/ops/counter.py | 13
-rw-r--r-- tensorflow/contrib/data/python/ops/enumerate_ops.py | 15
-rw-r--r-- tensorflow/contrib/data/python/ops/error_ops.py | 37
-rw-r--r-- tensorflow/contrib/data/python/ops/get_single_element.py | 29
-rw-r--r-- tensorflow/contrib/data/python/ops/grouping.py | 441
-rw-r--r-- tensorflow/contrib/data/python/ops/indexed_dataset_ops.py | 177
-rw-r--r-- tensorflow/contrib/data/python/ops/interleave_ops.py | 149
-rw-r--r-- tensorflow/contrib/data/python/ops/iterator_ops.py | 167
-rw-r--r-- tensorflow/contrib/data/python/ops/map_defun.py | 56
-rw-r--r-- tensorflow/contrib/data/python/ops/optimization.py | 171
-rw-r--r-- tensorflow/contrib/data/python/ops/parsing_ops.py | 107
-rw-r--r-- tensorflow/contrib/data/python/ops/prefetching_ops.py | 486
-rw-r--r-- tensorflow/contrib/data/python/ops/random_ops.py | 34
-rw-r--r-- tensorflow/contrib/data/python/ops/readers.py | 674
-rw-r--r-- tensorflow/contrib/data/python/ops/resampling.py | 260
-rw-r--r-- tensorflow/contrib/data/python/ops/scan_ops.py | 137
-rw-r--r-- tensorflow/contrib/data/python/ops/shuffle_ops.py | 56
-rw-r--r-- tensorflow/contrib/data/python/ops/stats_ops.py | 201
-rw-r--r-- tensorflow/contrib/data/python/ops/threadpool.py | 88
-rw-r--r-- tensorflow/contrib/data/python/ops/unique.py | 43
-rw-r--r-- tensorflow/contrib/data/python/ops/writers.py | 40
101 files changed, 415 insertions(+), 19824 deletions(-)
diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md
index 848782e8d8..90be7a66ca 100644
--- a/tensorflow/contrib/data/README.md
+++ b/tensorflow/contrib/data/README.md
@@ -1,10 +1,12 @@
`tf.contrib.data` API
=====================
-NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead.
-We are continuing to support existing code using the `tf.contrib.data` APIs in
-the current version of TensorFlow, but will eventually remove support. The
-`tf.data` APIs are subject to backwards compatibility guarantees.
+NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead,
+or `tf.data.experimental` for the experimental transformations previously hosted
+in this module. We are continuing to support existing code using the
+`tf.contrib.data` APIs in the current version of TensorFlow, but will eventually
+remove support. The non-experimental `tf.data` APIs are subject to backwards
+compatibility guarantees.
Porting your code to `tf.data`
------------------------------
@@ -25,13 +27,13 @@ instead apply them using `Dataset.apply()` transformation. The full list of
changes is as follows:
* `dataset.dense_to_sparse_batch(...)` is now
- `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`.
+ `dataset.apply(tf.data.experimental.dense_to_sparse_batch(...)`.
* `dataset.enumerate(...)` is now
- `dataset.apply(tf.contrib.data.enumerate_dataset(...))`.
+ `dataset.apply(tf.data.experimental.enumerate_dataset(...))`.
* `dataset.group_by_window(...)` is now
- `dataset.apply(tf.contrib.data.group_by_window(...))`.
+ `dataset.apply(tf.data.experimental.group_by_window(...))`.
* `dataset.ignore_errors()` is now
- `dataset.apply(tf.contrib.data.ignore_errors())`.
+ `dataset.apply(tf.data.experimental.ignore_errors())`.
* `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`.
The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have
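As a concrete illustration of the `Dataset.apply()` pattern in the porting list above, here is a small sketch using `dense_to_sparse_batch` under its new `tf.data.experimental` endpoint; the tiny input pipeline is illustrative only and is not part of this change:

```python
import tensorflow as tf

# Variable-length rows: element i is a vector of length i filled with i.
dataset = tf.data.Dataset.range(8).map(lambda x: tf.fill([x], x))

# Formerly: dataset.apply(tf.contrib.data.dense_to_sparse_batch(4, [8]))
# Now the same transformation is applied as:
dataset = dataset.apply(
    tf.data.experimental.dense_to_sparse_batch(batch_size=4, row_shape=[8]))
# Each element is now a SparseTensor batch of up to 4 rows padded to width 8.
```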
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 3cb51279c3..c3d3e981fa 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -96,10 +96,6 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
-
-# Optimization constant that can be used to enable auto-tuning.
-from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
-
from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -114,11 +110,12 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
-from tensorflow.contrib.data.python.ops.stats_ops import latency_stats
-from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator
-from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
from tensorflow.python.data.ops.optional_ops import Optional
# pylint: enable=unused-import
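One detail worth noting in this file: `AUTOTUNE` is now re-exported from the core `tensorflow.python.data.experimental` package rather than defined under contrib, but the endpoint `tf.contrib.data.AUTOTUNE` and its usage are unchanged. A small sketch, assuming a 1.x build where `Dataset.map` accepts the autotuning constant for `num_parallel_calls`:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1000)

# AUTOTUNE asks tf.data to pick the level of parallelism itself; the
# constant is still reachable as tf.contrib.data.AUTOTUNE after this change.
dataset = dataset.map(lambda x: x * 2,
                      num_parallel_calls=tf.contrib.data.AUTOTUNE)
```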
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 33784afa3f..42f538b4ba 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -8,51 +8,17 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
- name = "batch_dataset_op_test",
- size = "medium",
- srcs = ["batch_dataset_op_test.py"],
+ name = "assert_element_shape_test",
+ srcs = ["assert_element_shape_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss", # (b/79552534)
- "no_pip",
- ],
deps = [
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "bucketing_test",
- size = "medium",
- srcs = ["bucketing_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/kernel_tests:test_base",
@@ -62,147 +28,6 @@ py_test(
)
py_test(
- name = "csv_dataset_op_test",
- size = "medium",
- srcs = ["csv_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:session",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/eager:context",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "dataset_constructor_op_test",
- size = "medium",
- srcs = ["dataset_constructor_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "manual",
- "nomac", # b/62040583
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_test(
- name = "directed_interleave_dataset_test",
- size = "medium",
- srcs = ["directed_interleave_dataset_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:random_seed",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "get_single_element_test",
- size = "small",
- srcs = ["get_single_element_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:get_single_element",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "indexed_dataset_ops_test",
- srcs = ["indexed_dataset_ops_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "interleave_dataset_op_test",
- size = "medium",
- srcs = ["interleave_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "notap",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator:estimator_py",
- ],
-)
-
-py_test(
name = "lmdb_dataset_op_test",
size = "medium",
srcs = ["lmdb_dataset_op_test.py"],
@@ -229,252 +54,18 @@ py_test(
)
py_test(
- name = "map_dataset_op_test",
- size = "medium",
- srcs = ["map_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "noasan", # times out
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "filter_dataset_op_test",
- size = "medium",
- srcs = ["filter_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "map_defun_op_test",
+ name = "reduce_dataset_test",
size = "small",
- srcs = ["map_defun_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ srcs = ["reduce_dataset_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:map_defun",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:session",
- "//tensorflow/python/data/kernel_tests:test_base",
- ],
-)
-
-py_test(
- name = "parsing_ops_test",
- size = "small",
- srcs = ["parsing_ops_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:parsing_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-cuda_py_test(
- name = "prefetching_ops_test",
- size = "small",
- srcs = ["prefetching_ops_test.py"],
- additional_deps = [
- "//tensorflow/contrib/data/python/ops:prefetching_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
- tags = ["no_windows_gpu"],
-)
-
-py_test(
- name = "range_dataset_op_test",
- size = "small",
- srcs = ["range_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:counter",
- "//tensorflow/contrib/data/python/ops:enumerate_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_library(
- name = "reader_dataset_ops_test_base",
- testonly = 1,
- srcs = [
- "reader_dataset_ops_test_base.py",
- ],
- srcs_version = "PY2AND3",
- visibility = [
- "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/core:protos_all_py",
+ "//tensorflow/contrib/data/python/ops:get_single_element",
+ "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "reader_dataset_ops_test",
- size = "medium",
- srcs = ["reader_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":reader_dataset_ops_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "resample_test",
- size = "medium",
- srcs = ["resample_test.py"],
- shard_count = 2,
- srcs_version = "PY2AND3",
- tags = [
- "noasan",
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:resampling",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:util",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "scan_dataset_op_test",
- size = "small",
- srcs = ["scan_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:scan_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/eager:context",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "shuffle_dataset_op_test",
- size = "medium",
- srcs = ["shuffle_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
@@ -496,142 +87,3 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
-
-py_library(
- name = "sql_dataset_op_test_base",
- srcs = ["sql_dataset_op_test_base.py"],
- srcs_version = "PY2AND3",
- visibility = [
- "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/kernel_tests:test_base",
- "@org_sqlite//:python",
- ],
-)
-
-py_test(
- name = "sql_dataset_op_test",
- size = "small",
- srcs = ["sql_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":sql_dataset_op_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- ],
-)
-
-py_test(
- name = "stats_dataset_ops_test",
- size = "medium",
- srcs = ["stats_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":reader_dataset_ops_test_base",
- ":stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "stats_dataset_test_base",
- srcs = ["stats_dataset_test_base.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/kernel_tests:test_base",
- ],
-)
-
-py_test(
- name = "threadpool_dataset_ops_test",
- size = "small",
- srcs = ["threadpool_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:threadpool",
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:script_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "unique_dataset_op_test",
- size = "small",
- srcs = ["unique_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "window_dataset_op_test",
- size = "medium",
- srcs = ["window_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "writer_ops_test",
- size = "small",
- srcs = ["writer_ops_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:writers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:readers",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
new file mode 100644
index 0000000000..0456463a19
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
@@ -0,0 +1,226 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class AssertElementShapeTest(test_base.DatasetTestBase):
+
+ def test_assert_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(expected_shapes, dataset.output_shapes)
+
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def test_assert_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ partial_expected_shape = (
+ tensor_shape.TensorShape(None), # Unknown shape
+ tensor_shape.TensorShape((None, 4))) # Partial shape
+ result = dataset.apply(
+ batching.assert_element_shape(partial_expected_shape))
+ # Partial shapes are merged with actual shapes:
+ actual_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(actual_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
deleted file mode 100644
index fed7de5f2b..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ /dev/null
@@ -1,987 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-import time
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
-
- def testDenseToSparseBatchDataset(self):
- components = np.random.randint(12, size=(100,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- for start in range(0, len(components), 4):
- results = sess.run(get_next)
- self.assertAllEqual([[i, j]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)], results.indices)
- self.assertAllEqual(
- [c for c in components[start:start + 4] for _ in range(c)],
- results.values)
- self.assertAllEqual([min(4,
- len(components) - start), 12],
- results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithUnknownShape(self):
- components = np.random.randint(5, size=(40,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x, x], x)).apply(
- batching.dense_to_sparse_batch(
- 4, [5, None])).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- for start in range(0, len(components), 4):
- results = sess.run(get_next)
- self.assertAllEqual([[i, j, z]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)
- for z in range(c)], results.indices)
- self.assertAllEqual([
- c
- for c in components[start:start + 4] for _ in range(c)
- for _ in range(c)
- ], results.values)
- self.assertAllEqual([
- min(4,
- len(components) - start), 5,
- np.max(components[start:start + 4])
- ], results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithInvalidShape(self):
- input_tensor = array_ops.constant([[1]])
- with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
-
- def testDenseToSparseBatchDatasetShapeErrors(self):
- input_tensor = array_ops.placeholder(dtypes.int32)
- iterator = (
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Initialize with an input tensor of incompatible rank.
- sess.run(init_op, feed_dict={input_tensor: [[1]]})
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "incompatible with the row shape"):
- sess.run(get_next)
-
- # Initialize with an input tensor that is larger than `row_shape`.
- sess.run(init_op, feed_dict={input_tensor: range(13)})
- with self.assertRaisesRegexp(errors.DataLossError,
- "larger than the row shape"):
- sess.run(get_next)
-
- def testUnbatchScalarDataset(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = (dtypes.int32,) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i,) * 3, sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithStrings(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
- expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors(st)
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- st_row = sess.run(next_element)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchDatasetWithDenseAndSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- dense_elem, st_row = sess.run(next_element)
- self.assertEqual(i, dense_elem)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchSingleElementTupleDataset(self):
- data = tuple([(math_ops.range(10),) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32,),) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i,),) * 3, sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchMultiElementTupleDataset(self):
- data = tuple([(math_ops.range(10 * i, 10 * i + 10),
- array_ops.fill([10], "hi")) for i in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32, dtypes.string),) * 3
- data = data.batch(2)
- self.assertAllEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertAllEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
- sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchEmpty(self):
- data = dataset_ops.Dataset.from_tensors(
- (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
- constant_op.constant([], shape=[0, 4, 0])))
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchStaticShapeMismatch(self):
- data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
- np.arange(9)))
- with self.assertRaises(ValueError):
- data.apply(batching.unbatch())
-
- def testUnbatchDynamicShapeMismatch(self):
- ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
- ph2 = array_ops.placeholder(dtypes.int32, shape=None)
- data = dataset_ops.Dataset.from_tensors((ph1, ph2))
- data = data.apply(batching.unbatch())
- iterator = data.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- # Mismatch in the 0th dimension.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: np.arange(8).astype(np.int32)
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- # No 0th dimension (i.e. scalar value) for one component.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: 7
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- def testBatchAndDropRemainder(self):
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size))
- .make_initializable_iterator())
-
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_batch_size in [1, 3, 7, 10]:
- sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
- num_batches = 7 // test_batch_size
- for i in range(num_batches):
- result = sess.run(next_element)
- for component, result_component in zip(components, result):
- for j in range(test_batch_size):
- self.assertAllEqual(component[(i * test_batch_size + j)],
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testBatchAndDropRemainderSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(12).map(_sparse).apply(
- batching.batch_and_drop_remainder(5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(2):
- actual = sess.run(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testPaddedBatchAndDropRemainder(self):
- els = []
- for length in [3, 6, 9, 4, 12, 10, 2]:
- els.append((np.array(length), np.arange(length) + 1,
- np.array(length * 2)))
-
- dataset = dataset_ops.Dataset.from_tensors(els[0])
- for el in els[1:]:
- dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el))
-
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(
- batching.padded_batch_and_drop_remainder(
- batch_size, ([], [None], []))).make_initializable_iterator())
-
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_batch_size in [1, 3, 7, 10]:
- sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
- num_batches = 7 // test_batch_size
- for i in range(num_batches):
- result = sess.run(next_element)
- for component_idx, result_component in enumerate(result):
- for j in range(test_batch_size):
- data_idx = i * test_batch_size + j
- comp = result_component[j]
- unpadded = comp[comp > 0]
- if np.isscalar(comp):
- # The boolean mask indexing above adds a dim back. Rm it.
- unpadded = unpadded[0]
- self.assertAllEqual(els[data_idx][component_idx], unpadded)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPaddedBatchAndDropRemainderSparseError(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
-
- with self.assertRaises(TypeError):
- _ = dataset_ops.Dataset.range(10).map(_map_fn).apply(
- batching.padded_batch_and_drop_remainder(5))
-
- def testBatchAndDropRemainderShapeInference(self):
- components = (array_ops.placeholder(dtypes.int32),
- (array_ops.placeholder(dtypes.int32, shape=[None]),
- array_ops.placeholder(dtypes.int32, shape=[20, 30])))
-
- # Test with a statically known batch size.
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(128)))
-
- self.assertIs(None, dataset.output_shapes[0].ndims)
- self.assertEqual([128], dataset.output_shapes[1][0].as_list())
- self.assertEqual([128, 30], dataset.output_shapes[1][1].as_list())
-
- # Test with a dynamic batch size: the static shape will be unknown, because
- # `batch_size` is a placeholder.
- batch_size = array_ops.placeholder(dtypes.int64)
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size)))
-
- self.assertIs(None, dataset.output_shapes[0].ndims)
- self.assertEqual([None], dataset.output_shapes[1][0].as_list())
- self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
-
- @parameterized.named_parameters(
- ("Default", None, None),
- ("SequentialCalls", 1, None),
- ("ParallelCalls", 2, None),
- ("ParallelBatches", None, 10),
- )
- def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
- """Test a dataset that maps a TF function across its input elements."""
- # The pipeline is TensorSliceDataset ->
- # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- count = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_calls=num_parallel_calls,
- num_parallel_batches=num_parallel_batches))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([[None] + list(c.shape[1:]) for c in components],
- [t.shape.as_list() for t in get_next])
-
- with self.cached_session() as sess:
- # Batch of a finite input, where the batch_size divides the
- # total number of elements.
- sess.run(init_op, feed_dict={count: 28, batch_size: 14})
- num_batches = (28 * 7) // 14
- for i in range(num_batches):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i * 14 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of a finite input, where the batch_size does not
- # divide the total number of elements.
- sess.run(init_op, feed_dict={count: 14, batch_size: 8})
-
- # We expect (num_batches - 1) full-sized batches.
- num_batches = int(math.ceil((14 * 7) / 8))
- for i in range(num_batches - 1):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(8):
- self.assertAllEqual(component[(i * 8 + j) % 7]**2,
- result_component[j])
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of an empty input should fail straight away.
- sess.run(init_op, feed_dict={count: 0, batch_size: 8})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Empty batch should be an initialization time error.
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, batch_size: 0})
-
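The preceding test exercises the core contract of `map_and_batch`: a fused map followed by batch, which should yield the same batches as the unfused pipeline. A minimal sketch of that equivalence, reusing the graph-mode TF 1.x imports of this file; the helper name `_square` is illustrative only:

    def _square(x):
      return math_ops.square(x)

    fused = dataset_ops.Dataset.range(10).apply(
        batching.map_and_batch(_square, batch_size=4))
    unfused = dataset_ops.Dataset.range(10).map(_square).batch(4)
    # Both pipelines yield [0, 1, 4, 9], [16, 25, 36, 49] and, unless
    # drop_remainder=True, a final partial batch [64, 81].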
- @parameterized.named_parameters(
- ("Even", False),
- ("Uneven", True),
- )
- def testMapAndBatchPartialBatch(self, drop_remainder):
- iterator = (
- dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]),
- batch_size=4,
- drop_remainder=drop_remainder)).make_one_shot_iterator())
- if drop_remainder:
- self.assertEqual([4, 1], iterator.output_shapes.as_list())
- else:
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
- if not drop_remainder:
- self.assertAllEqual([[64], [81]], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchYieldsPartialBatch(self):
- iterator = (dataset_ops.Dataset.range(10)
- .apply(batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]), 4))
- .make_one_shot_iterator())
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
- self.assertAllEqual([[64], [81]], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchParallelGetNext(self):
- iterator = (dataset_ops.Dataset.range(50000)
- .apply(batching.map_and_batch(lambda x: x, batch_size=100))
- .make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(5):
- got = sess.run(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchParallelGetNextDropRemainder(self):
- iterator = (
- dataset_ops.Dataset.range(49999).apply(
- batching.map_and_batch(
- lambda x: x, batch_size=100, drop_remainder=True))
- .make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(4):
- got = sess.run(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(2):
- actual = sess.run(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testMapAndBatchFails(self):
-    """Test that an error in the input pipeline surfaces when the iterator is initialized."""
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.check_numerics(
- constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
- init_op = iterator.initializer
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
- sess.run(init_op, feed_dict={batch_size: 14})
-
- def testMapAndBatchShapeMismatch(self):
-    """Test that batching elements whose shapes do not match raises an error."""
-
- def generator():
- yield [1]
- yield [2]
- yield [3]
- yield [[4, 5, 6]]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator, output_types=dtypes.int32)
- batch_size = 4
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "number of elements does not match"):
- sess.run(get_next)
-
- def testMapAndBatchImplicitDispose(self):
- # Tests whether a map and batch dataset will be cleaned up correctly when
- # the pipeline does not run it until exhaustion.
- # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
- # MapAndBatchDataset(f=square_3, batch_size=100).
- components = (np.arange(1000),
- np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
- np.array(37.0) * np.arange(1000))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
- 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
- dataset = dataset.prefetch(5)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for _ in range(3):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", 0),
- ("2", 5),
- ("3", 10),
- ("4", 90),
- ("5", 95),
- ("6", 99),
- )
- def testMapAndBatchOutOfRangeError(self, threshold):
-
- def raising_py_fn(i):
- if i >= threshold:
- raise StopIteration()
- else:
- return i
-
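-    # `raising_py_fn` raises StopIteration once `i` reaches `threshold`; the
-    # loop below therefore expects every full batch below `threshold`, then a
-    # partial batch for any remainder, and finally an OutOfRangeError.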
- iterator = (
- dataset_ops.Dataset.range(100).apply(
- batching.map_and_batch(
- lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
- batch_size=10)).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(threshold // 10):
- self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
- if threshold % 10 != 0:
- self.assertAllEqual(
- [threshold // 10 * 10 + j for j in range(threshold % 10)],
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", False, dtypes.bool),
- ("2", -42, dtypes.int8),
- ("3", -42, dtypes.int16),
- ("4", -42, dtypes.int32),
- ("5", -42, dtypes.int64),
- ("6", 42, dtypes.uint8),
- ("7", 42, dtypes.uint16),
- ("8", 42.0, dtypes.float16),
- ("9", 42.0, dtypes.float32),
- ("10", 42.0, dtypes.float64),
- ("11", b"hello", dtypes.string),
- )
- def testMapAndBatchTypes(self, element, dtype):
- def gen():
- yield element
-
- dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
- batching.map_and_batch(lambda x: x, batch_size=10))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- for _ in range(10):
- self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-
-
-class RestructuredDatasetTest(test_base.DatasetTestBase):
-
- def test_assert_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(5).map(create_dataset)
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- self.assertEqual(expected_shapes, dataset.output_shapes)
-
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(3).map(create_dataset)
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
- with self.assertRaises(ValueError):
- dataset.apply(batching.assert_element_shape(wrong_shapes))
-
- def test_assert_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
- iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def test_assert_partial_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(5).map(create_dataset)
- partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape
- tensor_shape.TensorShape((None, 4))) # Partial shape
- result = dataset.apply(
- batching.assert_element_shape(partial_expected_shape))
- # Partial shapes are merged with actual shapes:
- actual_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- self.assertEqual(actual_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
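The merged output shapes asserted above match TensorShape merging semantics: each known dimension of the actual shape fills in the corresponding `None` of the partial shape. A small sketch of that rule using the `tensor_shape` module already imported in this file (illustrative only, not the implementation of `assert_element_shape`):

    partial = tensor_shape.TensorShape((None, 4))
    concrete = tensor_shape.TensorShape((3, 4))
    print(partial.merge_with(concrete))  # prints (3, 4)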
- def test_assert_wrong_partial_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(3).map(create_dataset)
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 10)))
- with self.assertRaises(ValueError):
- dataset.apply(batching.assert_element_shape(wrong_shapes))
-
- def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 4)))
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 10)))
- iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
-
-class UnbatchDatasetBenchmark(test.Benchmark):
-
- def benchmarkNativeUnbatch(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.apply(batching.unbatch())
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (native) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_native_batch_size_%d" %
- batch_size)
-
- # Include a benchmark of the previous `unbatch()` implementation that uses
- # a composition of more primitive ops. Eventually we'd hope to generate code
- # that is as good in both cases.
- def benchmarkOldUnbatchImplementation(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (unfused) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
- batch_size)
-
-
-if __name__ == "__main__":
- test.main()
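The two benchmarks above time the same transformation expressed two ways: the fused `unbatch()` op and its composition out of primitive ops. A compact restatement of that equivalence, assuming the same imports as the benchmarked file:

    batched = dataset_ops.Dataset.range(8).batch(4)
    native = batched.apply(batching.unbatch())
    unfused = batched.flat_map(dataset_ops.Dataset.from_tensor_slices)
    # Both produce the scalar elements 0..7; the benchmarks compare the fused
    # kernel against this flat_map composition.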
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
deleted file mode 100644
index ae401f786c..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ /dev/null
@@ -1,824 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import random
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-
-
-class GroupByReducerTest(test_base.DatasetTestBase):
-
- def checkResults(self, dataset, shapes, values):
- self.assertEqual(shapes, dataset.output_shapes)
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- for expected in values:
- got = sess.run(get_next)
- self.assertEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSum(self):
- reducer = grouping.Reducer(
- init_func=lambda _: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).apply(
- grouping.group_by_reducer(lambda x: x % 2, reducer))
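-      # The evens in range(2 * i) sum to (i - 1) * i; the odds sum to i * i.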
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
-
- def testAverage(self):
-
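-    # The reducer state is a (running_mean, count) pair: each step folds the
-    # new value into the mean and increments the count; finalize drops the
-    # count and returns the mean.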
- def reduce_fn(x, y):
- return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
- x[1] + 1), x[1] + 1
-
- reducer = grouping.Reducer(
- init_func=lambda _: (0.0, 0.0),
- reduce_func=reduce_fn,
- finalize_func=lambda x, _: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).apply(
- grouping.group_by_reducer(
- lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
-
- def testConcat(self):
- components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
- reducer = grouping.Reducer(
- init_func=lambda x: "",
- reduce_func=lambda x, y: x + y[0],
- finalize_func=lambda x: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensor_slices(components),
- dataset_ops.Dataset.range(2 * i))).apply(
- grouping.group_by_reducer(lambda x, y: y % 2, reducer))
- self.checkResults(
- dataset,
- shapes=tensor_shape.scalar(),
-          values=[b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]])
-
- def testSparseSum(self):
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1], dtype=np.int64)),
- dense_shape=np.array([1, 1]))
-
- reducer = grouping.Reducer(
- init_func=lambda _: _sparse(np.int64(0)),
- reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
- finalize_func=lambda x: x.values[0])
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
- grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
-
- def testChangingStateShape(self):
-
- def reduce_fn(x, _):
- # Statically known rank, but dynamic length.
- larger_dim = array_ops.concat([x[0], x[0]], 0)
- # Statically unknown rank.
- larger_rank = array_ops.expand_dims(x[1], 0)
- return larger_dim, larger_rank
-
- reducer = grouping.Reducer(
- init_func=lambda x: ([0], 1),
- reduce_func=reduce_fn,
- finalize_func=lambda x, y: (x, y))
-
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
- grouping.group_by_reducer(lambda x: x, reducer))
- self.assertEqual([None], dataset.output_shapes[0].as_list())
- self.assertIs(None, dataset.output_shapes[1].ndims)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- x, y = sess.run(get_next)
- self.assertAllEqual([0] * (2**i), x)
- self.assertAllEqual(np.array(1, ndmin=i), y)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testTypeMismatch(self):
- reducer = grouping.Reducer(
- init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
- reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- TypeError,
- "The element types for the new state must match the initial state."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: np.int64(0), reducer))
-
- # TODO(b/78665031): Remove once non-scalar keys are supported.
- def testInvalidKeyShape(self):
- reducer = grouping.Reducer(
- init_func=lambda x: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- ValueError, "`key_func` must return a single tf.int64 tensor."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
-
- # TODO(b/78665031): Remove once non-int64 keys are supported.
- def testInvalidKeyType(self):
- reducer = grouping.Reducer(
- init_func=lambda x: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- ValueError, "`key_func` must return a single tf.int64 tensor."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: "wrong", reducer))
-
- def testTuple(self):
- def init_fn(_):
- return np.array([], dtype=np.int64), np.int64(0)
-
- def reduce_fn(state, value):
- s1, s2 = state
- v1, v2 = value
- return array_ops.concat([s1, [v1]], 0), s2 + v2
-
- def finalize_fn(s1, s2):
- return s1, s2
-
- reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
- dataset = dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
- grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- x, y = sess.run(get_next)
- self.assertAllEqual(x, np.asarray([x for x in range(10)]))
- self.assertEqual(y, 45)
-
-
-class GroupByWindowTest(test_base.DatasetTestBase):
-
- def testSimple(self):
- components = np.random.randint(100, size=(200,)).astype(np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
- .apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- counts = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- result = sess.run(get_next)
-          self.assertTrue(
-              all(x % 2 == 0 for x in result) or
-              all(x % 2 == 1 for x in result))
- counts.append(result.shape[0])
-
- self.assertEqual(len(components), sum(counts))
- num_full_batches = len([c for c in counts if c == 4])
- self.assertGreaterEqual(num_full_batches, 24)
- self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
-
- def testImmediateOutput(self):
- components = np.array(
- [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
- grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- # The input is infinite, so this test demonstrates that:
- # 1. We produce output without having to consume the entire input,
- # 2. Different buckets can produce output at different rates, and
- # 3. For deterministic input, the output is deterministic.
- for _ in range(3):
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
- self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
-
- def testSmallGroups(self):
- components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
- # The small outputs at the end are deterministically produced in key
- # order.
- self.assertAllEqual([0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1], sess.run(get_next))
-
- def testEmpty(self):
- iterator = (
- dataset_ops.Dataset.range(4).apply(
- grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Window size must be greater than zero, but got 0."):
- print(sess.run(get_next))
-
- def testReduceFuncError(self):
- components = np.random.randint(100, size=(200,)).astype(np.int64)
-
- def reduce_func(_, xs):
- # Introduce an incorrect padded shape that cannot (currently) be
- # detected at graph construction time.
- return xs.padded_batch(
- 4,
- padded_shapes=(tensor_shape.TensorShape([]),
- constant_op.constant([5], dtype=dtypes.int64) * -1))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
- grouping.group_by_window(lambda x, _: x % 2, reduce_func,
- 32)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def testConsumeWindowDatasetMoreThanOnce(self):
- components = np.random.randint(50, size=(200,)).astype(np.int64)
-
- def reduce_func(key, window):
- # Apply two different kinds of padding to the input: tight
- # padding, and quantized (to a multiple of 10) padding.
- return dataset_ops.Dataset.zip((
- window.padded_batch(
- 4, padded_shapes=tensor_shape.TensorShape([None])),
- window.padded_batch(
- 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
- ))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
- .apply(grouping.group_by_window(
- lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
- reduce_func, 4))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- counts = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- tight_result, multiple_of_10_result = sess.run(get_next)
- self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
- self.assertAllEqual(tight_result,
- multiple_of_10_result[:, :tight_result.shape[1]])
- counts.append(tight_result.shape[0])
- self.assertEqual(len(components), sum(counts))
-
-
-# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
-# Currently, they use a constant batch size, though they should be made to
-# use a different batch size per key.
-class BucketTest(test_base.DatasetTestBase):
-
- def _dynamicPad(self, bucket, window, window_size):
- # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
- # generic form of padded_batch that pads every component
- # dynamically and does not rely on static shape information about
- # the arguments.
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(bucket),
- window.padded_batch(
- 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
- [None]), tensor_shape.TensorShape([3])))))
-
- def testSingleBucket(self):
-
- def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda x, y, z: 0,
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- which_bucket, bucketed_values = sess.run(get_next)
-
- self.assertEqual(0, which_bucket)
-
- expected_scalar_int = np.arange(32, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
- for i in range(32):
- expected_unk_int64[i, :i] = i
- expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values[2])
-
- def testEvenOddBuckets(self):
-
- def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- # Get two minibatches (one containing even values, one containing odds)
- which_bucket_even, bucketed_values_even = sess.run(get_next)
- which_bucket_odd, bucketed_values_odd = sess.run(get_next)
-
- # Count number of bucket_tensors.
- self.assertEqual(3, len(bucketed_values_even))
- self.assertEqual(3, len(bucketed_values_odd))
-
- # Ensure bucket 0 was used for all minibatch entries.
- self.assertAllEqual(0, which_bucket_even)
- self.assertAllEqual(1, which_bucket_odd)
-
-      # Test the first bucket outputted, the evens starting at 0
- expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
- for i in range(0, 32):
- expected_unk_int64[i, :2 * i] = 2 * i
- expected_vec3_str = np.vstack(
- 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
-
- # Test the second bucket outputted, the odds starting at 1
- expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
- for i in range(0, 32):
- expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
- expected_vec3_str = np.vstack(
- 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
-
- def testEvenOddBucketsFilterOutAllOdd(self):
-
- def _map_fn(v):
- return {
- "x": v,
- "y": array_ops.fill([v], v),
- "z": array_ops.fill([3], string_ops.as_string(v))
- }
-
- def _dynamic_pad_fn(bucket, window, _):
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(bucket),
- window.padded_batch(
- 32, {
- "x": tensor_shape.TensorShape([]),
- "y": tensor_shape.TensorShape([None]),
- "z": tensor_shape.TensorShape([3])
- })))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
- .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
- lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- # Get two minibatches ([0, 2, ...] and [64, 66, ...])
- which_bucket0, bucketed_values_even0 = sess.run(get_next)
- which_bucket1, bucketed_values_even1 = sess.run(get_next)
-
- # Ensure that bucket 1 was completely filtered out
- self.assertAllEqual(0, which_bucket0)
- self.assertAllEqual(0, which_bucket1)
- self.assertAllEqual(
- np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
- self.assertAllEqual(
- np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
-
- def testDynamicWindowSize(self):
- components = np.arange(100).astype(np.int64)
-
- # Key fn: even/odd
-    # Reduce fn: batch the whole window (batch(20) is larger than any window)
- # Window size fn: even=5, odd=10
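-    # With 100 elements there are 50 evens and 50 odds: 50 / 5 = 10 even
-    # windows plus 50 / 10 = 5 odd windows, hence the 15 batches asserted
-    # below.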
-
- def window_size_func(key):
- window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
- return window_sizes[key]
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
- None, window_size_func))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- batches = 0
- while True:
- result = sess.run(get_next)
- is_even = all(x % 2 == 0 for x in result)
- is_odd = all(x % 2 == 1 for x in result)
- self.assertTrue(is_even or is_odd)
- expected_batch_size = 5 if is_even else 10
- self.assertEqual(expected_batch_size, result.shape[0])
- batches += 1
-
- self.assertEqual(batches, 15)
-
-
-def _element_length_fn(x, y=None):
- del y
- return array_ops.shape(x)[0]
-
-
-def _to_sparse_tensor(record):
- return sparse_tensor.SparseTensor(**record)
-
-
-def _format_record(array, sparse):
- if sparse:
- return {
- "values": array,
- "indices": [[i] for i in range(len(array))],
- "dense_shape": (len(array),)
- }
- return array
-
-
-def _get_record_type(sparse):
- if sparse:
- return {
- "values": dtypes.int64,
- "indices": dtypes.int64,
- "dense_shape": dtypes.int64
- }
- return dtypes.int32
-
-
-def _get_record_shape(sparse):
- if sparse:
- return {
- "values": tensor_shape.TensorShape([None,]),
- "indices": tensor_shape.TensorShape([None, 1]),
- "dense_shape": tensor_shape.TensorShape([1,])
- }
- return tensor_shape.TensorShape([None])
-
-
-class BucketBySequenceLength(test_base.DatasetTestBase):
-
- def testBucket(self):
-
- boundaries = [10, 20, 30]
- batch_sizes = [10, 8, 4, 2]
- lengths = [8, 13, 25, 35]
-
- def build_dataset(sparse):
- def _generator():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes, lengths):
- record_len = length - 1
- for _ in range(batch_size):
- elements.append([1] * record_len)
- record_len = length
- random.shuffle(elements)
- for el in elements:
- yield (_format_record(el, sparse),)
- dataset = dataset_ops.Dataset.from_generator(
- _generator,
- (_get_record_type(sparse),),
- (_get_record_shape(sparse),))
- if sparse:
- dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
- return dataset
-
- def _test_bucket_by_padding(no_padding):
- dataset = build_dataset(sparse=no_padding)
- dataset = dataset.apply(
- grouping.bucket_by_sequence_length(
- _element_length_fn,
- boundaries,
- batch_sizes,
- no_padding=no_padding))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(4):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- shape = batch.dense_shape if no_padding else batch.shape
- batch_size = shape[0]
- length = shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- sum_check = batch.values.sum() if no_padding else batch.sum()
- self.assertEqual(sum_check, batch_size * length - 1)
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(lengths), sorted(lengths_val))
-
- for no_padding in (True, False):
- _test_bucket_by_padding(no_padding)
-
- def testPadToBoundary(self):
-
- boundaries = [10, 20, 30]
- batch_sizes = [10, 8, 4, 2]
- lengths = [8, 13, 25]
-
- def element_gen():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes[:-1], lengths):
- for _ in range(batch_size):
- elements.append([1] * length)
- random.shuffle(elements)
- for el in elements:
- yield (el,)
- for _ in range(batch_sizes[-1]):
- el = [1] * (boundaries[-1] + 5)
- yield (el,)
-
- element_len = lambda el: array_ops.shape(el)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes,
- pad_to_bucket_boundary=True))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(3):
- batches.append(sess.run(batch))
- with self.assertRaisesOpError("bucket_boundaries"):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- batch_size = batch.shape[0]
- length = batch.shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- batch_sizes = batch_sizes[:-1]
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
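-      # With pad_to_bucket_boundary=True each batch is padded to one less than
-      # its bucket's upper boundary, hence the boundary - 1 lengths below.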
- self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
- sorted(lengths_val))
-
- def testPadToBoundaryNoExtraneousPadding(self):
-
- boundaries = [3, 7, 11]
- batch_sizes = [2, 2, 2, 2]
- lengths = range(1, 11)
-
- def element_gen():
- for length in lengths:
- yield ([1] * length,)
-
- element_len = lambda element: array_ops.shape(element)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes,
- pad_to_bucket_boundary=True))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(5):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
-
- self.assertAllEqual(batches[0], [[1, 0],
- [1, 1]])
- self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 0, 0]])
- self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1]])
- self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
- self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
-
- def testTupleElements(self):
-
- def build_dataset(sparse):
- def _generator():
- text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
- label = [1, 2, 1, 2]
- for x, y in zip(text, label):
- yield (_format_record(x, sparse), y)
- dataset = dataset_ops.Dataset.from_generator(
- generator=_generator,
- output_types=(_get_record_type(sparse), dtypes.int32),
- output_shapes=(_get_record_shape(sparse),
- tensor_shape.TensorShape([])))
- if sparse:
- dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
- return dataset
-
- def _test_tuple_elements_by_padding(no_padding):
- dataset = build_dataset(sparse=no_padding)
- dataset = dataset.apply(grouping.bucket_by_sequence_length(
- element_length_func=_element_length_fn,
- bucket_batch_sizes=[2, 2, 2],
- bucket_boundaries=[0, 8],
- no_padding=no_padding))
- shapes = dataset.output_shapes
- self.assertEqual([None, None], shapes[0].as_list())
- self.assertEqual([None], shapes[1].as_list())
-
- for no_padding in (True, False):
- _test_tuple_elements_by_padding(no_padding)
-
- def testBucketSparse(self):
- """Tests bucketing of sparse tensors (case where `no_padding` == True).
-
-    Test runs on the following dataset:
- [
- [0],
- [0, 1],
- [0, 1, 2]
- ...
- [0, ..., max_len - 1]
- ]
- Sequences are bucketed by length and batched with
- `batch_size` < `bucket_size`.
- """
-
- min_len = 0
- max_len = 100
- batch_size = 7
- bucket_size = 10
-
- def _build_dataset():
- input_data = [range(i+1) for i in range(min_len, max_len)]
- def generator_fn():
- for record in input_data:
- yield _format_record(record, sparse=True)
- dataset = dataset_ops.Dataset.from_generator(
- generator=generator_fn,
- output_types=_get_record_type(sparse=True))
- dataset = dataset.map(_to_sparse_tensor)
- return dataset
-
- def _compute_expected_batches():
- """Computes expected batch outputs and stores in a set."""
- all_expected_sparse_tensors = set()
- for bucket_start_len in range(min_len, max_len, bucket_size):
- for batch_offset in range(0, bucket_size, batch_size):
- batch_start_len = bucket_start_len + batch_offset
- batch_end_len = min(batch_start_len + batch_size,
- bucket_start_len + bucket_size)
- expected_indices = []
- expected_values = []
- for length in range(batch_start_len, batch_end_len):
- for val in range(length + 1):
- expected_indices.append((length - batch_start_len, val))
- expected_values.append(val)
- expected_sprs_tensor = (tuple(expected_indices),
- tuple(expected_values))
- all_expected_sparse_tensors.add(expected_sprs_tensor)
- return all_expected_sparse_tensors
-
- def _compute_batches(dataset):
- """Computes actual batch outputs of dataset and stores in a set."""
- batch = dataset.make_one_shot_iterator().get_next()
- all_sparse_tensors = set()
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- output = sess.run(batch)
- sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
- tuple(output.values))
- all_sparse_tensors.add(sprs_tensor)
- return all_sparse_tensors
-
- dataset = _build_dataset()
- boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
- dataset = dataset.apply(grouping.bucket_by_sequence_length(
- _element_length_fn,
- boundaries,
- [batch_size] * (len(boundaries) + 1),
- no_padding=True))
- batches = _compute_batches(dataset)
- expected_batches = _compute_expected_batches()
- self.assertEqual(batches, expected_batches)
-
-
-if __name__ == "__main__":
- test.main()
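The grouping tests above collectively document the bucketing surface of `tf.contrib.data`. A minimal end-to-end sketch of `bucket_by_sequence_length`, assuming the same TF 1.x imports as the file above; the sequence lengths chosen here are illustrative only:

    def gen():
      for n in [2, 11, 25, 3, 14]:
        yield np.ones([n], dtype=np.int64)

    ds = dataset_ops.Dataset.from_generator(
        gen, dtypes.int64, tensor_shape.TensorShape([None]))
    ds = ds.apply(grouping.bucket_by_sequence_length(
        element_length_func=lambda x: array_ops.shape(x)[0],
        bucket_boundaries=[10, 20],
        bucket_batch_sizes=[2, 2, 2]))
    # Elements are grouped into length buckets [0, 10), [10, 20) and [20, inf),
    # then padded-batched within each bucket.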
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
deleted file mode 100644
index 5b3c512b64..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ /dev/null
@@ -1,632 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for CsvDatasetOp."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import string
-import tempfile
-import time
-import zlib
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import googletest
-from tensorflow.python.platform import test
-
-
-@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test_base.DatasetTestBase):
-
- def _setup_files(self, inputs, linebreak='\n', compression_type=None):
- filenames = []
- for i, ip in enumerate(inputs):
- fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
- contents = linebreak.join(ip).encode('utf-8')
- if compression_type is None:
- with open(fn, 'wb') as f:
- f.write(contents)
- elif compression_type == 'GZIP':
- with gzip.GzipFile(fn, 'wb') as f:
- f.write(contents)
- elif compression_type == 'ZLIB':
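-        # 'ZLIB' here means raw zlib-compressed bytes (no gzip container),
-        # so compress the contents in memory and write them directly.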
- contents = zlib.compress(contents)
- with open(fn, 'wb') as f:
- f.write(contents)
- else:
- raise ValueError('Unsupported compression_type', compression_type)
- filenames.append(fn)
- return filenames
-
- def _make_test_datasets(self, inputs, **kwargs):
-    # Test by comparing CsvDataset's output to TextLineDataset->map(decode_csv)
- filenames = self._setup_files(inputs)
- dataset_expected = core_readers.TextLineDataset(filenames)
- dataset_expected = dataset_expected.map(
- lambda l: parsing_ops.decode_csv(l, **kwargs))
- dataset_actual = readers.CsvDataset(filenames, **kwargs)
- return (dataset_actual, dataset_expected)
-
- def _test_by_comparison(self, inputs, **kwargs):
- """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
- dataset_actual, dataset_expected = self._make_test_datasets(
- inputs, **kwargs)
- self.assertDatasetsEqual(dataset_actual, dataset_expected)
-
- def _verify_output_or_err(self,
- dataset,
- expected_output=None,
- expected_err_re=None):
- if expected_err_re is None:
- # Verify that output is expected, without errors
- nxt = self.getNext(dataset)
- expected_output = [[
- v.encode('utf-8') if isinstance(v, str) else v for v in op
- ] for op in expected_output]
- for value in expected_output:
- op = self.evaluate(nxt())
- self.assertAllEqual(op, value)
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(nxt())
- else:
- # Verify that OpError is produced as expected
- with self.assertRaisesOpError(expected_err_re):
- nxt = self.getNext(dataset)
- while True:
- try:
- self.evaluate(nxt())
- except errors.OutOfRangeError:
- break
-
- def _test_dataset(
- self,
- inputs,
- expected_output=None,
- expected_err_re=None,
- linebreak='\n',
- compression_type=None, # Used for both setup and parsing
- **kwargs):
- """Checks that elements produced by CsvDataset match expected output."""
- # Convert str type because py3 tf strings are bytestrings
- filenames = self._setup_files(inputs, linebreak, compression_type)
- kwargs['compression_type'] = compression_type
- dataset = readers.CsvDataset(filenames, **kwargs)
- self._verify_output_or_err(dataset, expected_output, expected_err_re)
-
- def testCsvDataset_requiredFields(self):
- record_defaults = [[]] * 4
- inputs = [['1,2,3,4']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_int(self):
- record_defaults = [[0]] * 4
- inputs = [['1,2,3,4', '5,6,7,8']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_float(self):
- record_defaults = [[0.0]] * 4
- inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_string(self):
- record_defaults = [['']] * 4
- inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withEmptyFields(self):
- record_defaults = [[0]] * 4
- inputs = [[',,,', '1,1,1,', ',2,2,2']]
- self._test_dataset(
- inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
- record_defaults=record_defaults)
-
- def testCsvDataset_errWithUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4']]
- self._test_dataset(
- inputs,
- expected_err_re='Unquoted fields cannot have quotes inside',
- record_defaults=record_defaults)
-
- def testCsvDataset_errWithUnescapedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['"a"b","c","d"']]
- self._test_dataset(
- inputs,
- expected_err_re=
- 'Quote inside a string has to be escaped by another quote',
- record_defaults=record_defaults)
-
- def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
- filenames = self._setup_files(inputs)
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(dataset, [['e', 'f', 'g']])
-
- def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
- filenames = self._setup_files(inputs)
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(dataset, [['e', 'f', 'g']])
-
- def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, use_quote_delim=False)
-
- def testCsvDataset_mixedTypes(self):
- record_defaults = [
- constant_op.constant([], dtype=dtypes.int32),
- constant_op.constant([], dtype=dtypes.float32),
- constant_op.constant([], dtype=dtypes.string),
- constant_op.constant([], dtype=dtypes.float64)
- ]
- inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withUseQuoteDelimFalse(self):
- record_defaults = [['']] * 4
- inputs = [['1,2,"3,4"', '"5,6",7,8']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, use_quote_delim=False)
-
- def testCsvDataset_withFieldDelim(self):
- record_defaults = [[0]] * 4
- inputs = [['1:2:3:4', '5:6:7:8']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, field_delim=':')
-
- def testCsvDataset_withNaValue(self):
- record_defaults = [[0]] * 4
- inputs = [['1,NA,3,4', 'NA,6,7,8']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, na_value='NA')
-
- def testCsvDataset_withSelectCols(self):
- record_defaults = [['']] * 2
- inputs = [['1,2,3,4', '"5","6","7","8"']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, select_cols=[1, 2])
-
- def testCsvDataset_withSelectColsTooHigh(self):
- record_defaults = [[0]] * 2
- inputs = [['1,2,3,4', '5,6,7,8']]
- self._test_dataset(
- inputs,
- expected_err_re='Expect 2 fields but have 1 in record',
- record_defaults=record_defaults,
- select_cols=[3, 4])
-
- def testCsvDataset_withOneCol(self):
- record_defaults = [['NA']]
- inputs = [['0', '', '2']]
- self._test_dataset(
- inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
-
- def testCsvDataset_withMultipleFiles(self):
- record_defaults = [[0]] * 4
- inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withLeadingAndTrailingSpaces(self):
- record_defaults = [[0.0]] * 4
- inputs = [['0, 1, 2, 3']]
- expected = [[0.0, 1.0, 2.0, 3.0]]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_errorWithMissingDefault(self):
- record_defaults = [[]] * 2
- inputs = [['0,']]
- self._test_dataset(
- inputs,
- expected_err_re='Field 1 is required but missing in record!',
- record_defaults=record_defaults)
-
- def testCsvDataset_errorWithFewerDefaultsThanFields(self):
- record_defaults = [[0.0]] * 2
- inputs = [['0,1,2,3']]
- self._test_dataset(
- inputs,
- expected_err_re='Expect 2 fields but have more in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_errorWithMoreDefaultsThanFields(self):
- record_defaults = [[0.0]] * 5
- inputs = [['0,1,2,3']]
- self._test_dataset(
- inputs,
- expected_err_re='Expect 5 fields but have 4 in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_withHeader(self):
- record_defaults = [[0]] * 2
- inputs = [['col1,col2', '1,2']]
- expected = [[1, 2]]
- self._test_dataset(
- inputs,
- expected,
- record_defaults=record_defaults,
- header=True,
- )
-
- def testCsvDataset_withHeaderAndNoRecords(self):
- record_defaults = [[0]] * 2
- inputs = [['col1,col2']]
- expected = []
- self._test_dataset(
- inputs,
- expected,
- record_defaults=record_defaults,
- header=True,
- )
-
- def testCsvDataset_errorWithHeaderEmptyFile(self):
- record_defaults = [[0]] * 2
- inputs = [[]]
- expected_err_re = "Can't read header of file"
- self._test_dataset(
- inputs,
- expected_err_re=expected_err_re,
- record_defaults=record_defaults,
- header=True,
- )
-
- def testCsvDataset_withEmptyFile(self):
- record_defaults = [['']] * 2
- inputs = [['']] # Empty file
- self._test_dataset(
- inputs, expected_output=[], record_defaults=record_defaults)
-
- def testCsvDataset_errorWithEmptyRecord(self):
- record_defaults = [['']] * 2
- inputs = [['', '1,2']] # First record is empty
- self._test_dataset(
- inputs,
- expected_err_re='Expect 2 fields but have 1 in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_withChainedOps(self):
-    # Test that one dataset can create multiple iterators correctly.
- # `repeat` creates multiple iterators from the same C++ Dataset.
- record_defaults = [[0]] * 4
- inputs = [['1,,3,4', '5,6,,8']]
- ds_actual, ds_expected = self._make_test_datasets(
- inputs, record_defaults=record_defaults)
- self.assertDatasetsEqual(
- ds_actual.repeat(5).prefetch(1),
- ds_expected.repeat(5).prefetch(1))
-
- def testCsvDataset_withTypeDefaults(self):
- # Testing using dtypes as record_defaults for required fields
- record_defaults = [dtypes.float32, [0.0]]
- inputs = [['1.0,2.0', '3.0,4.0']]
- self._test_dataset(
- inputs,
- [[1.0, 2.0], [3.0, 4.0]],
- record_defaults=record_defaults,
- )
-
- def testMakeCsvDataset_fieldOrder(self):
- data = [[
- '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
- '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
- ]]
- file_path = self._setup_files(data)
-
- ds = readers.make_csv_dataset(
- file_path, batch_size=1, shuffle=False, num_epochs=1)
- nxt = self.getNext(ds)
-
- result = list(self.evaluate(nxt()).values())
-
- self.assertEqual(result, sorted(result))
-
-## The following tests exercise parsing logic for quoted fields
-
- def testCsvDataset_withQuoted(self):
- record_defaults = [['']] * 4
- inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withOneColAndQuotes(self):
- record_defaults = [['']]
- inputs = [['"0"', '"1"', '"2"']]
- self._test_dataset(
- inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
-
- def testCsvDataset_withNewLine(self):
-    # In this case, we expect it to behave differently from
-    # TextLineDataset->map(decode_csv), which cannot handle newlines inside
-    # quoted fields because TextLineDataset splits records before parsing.
- record_defaults = [['']] * 4
- inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
- expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_withNewLineInUnselectedCol(self):
- record_defaults = [['']]
- inputs = [['1,"2\n3",4', '5,6,7']]
- self._test_dataset(
- inputs,
- expected_output=[['1'], ['5']],
- record_defaults=record_defaults,
- select_cols=[0])
-
- def testCsvDataset_withMultipleNewLines(self):
-    # In this case, we expect it to behave differently from
-    # TextLineDataset->map(decode_csv), which cannot handle newlines inside
-    # quoted fields because TextLineDataset splits records before parsing.
- record_defaults = [['']] * 4
- inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
- expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_errorWithTerminateMidRecord(self):
- record_defaults = [['']] * 4
- inputs = [['a,b,c,"a']]
- self._test_dataset(
- inputs,
- expected_err_re=
- 'Reached end of file without closing quoted field in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_withEscapedQuotes(self):
- record_defaults = [['']] * 4
- inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
-
-## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
-## and different types of line breaks
-
- def testCsvDataset_withInvalidBufferSize(self):
- record_defaults = [['']] * 4
- inputs = [['a,b,c,d']]
- self._test_dataset(
- inputs,
- expected_err_re='buffer_size should be positive',
- record_defaults=record_defaults,
- buffer_size=0)
-
- def _test_dataset_on_buffer_sizes(self,
- inputs,
- expected,
- linebreak,
- record_defaults,
- compression_type=None,
- num_sizes_to_test=20):
-    # Test reading with a range of buffer sizes that should all work.
- for i in list(range(1, 1 + num_sizes_to_test)) + [None]:
- self._test_dataset(
- inputs,
- expected,
- linebreak=linebreak,
- compression_type=compression_type,
- record_defaults=record_defaults,
- buffer_size=i)
-
- def testCsvDataset_withLF(self):
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\n', record_defaults=record_defaults)
-
- def testCsvDataset_withCR(self):
- # Test that when the line separator is '\r', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\r', record_defaults=record_defaults)
-
- def testCsvDataset_withCRLF(self):
- # Test that when the line separator is '\r\n', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
-
- def testCsvDataset_withBufferSizeAndQuoted(self):
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\n', record_defaults=record_defaults)
-
- def testCsvDataset_withCRAndQuoted(self):
- # Test that when the line separator is '\r', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\r', record_defaults=record_defaults)
-
- def testCsvDataset_withCRLFAndQuoted(self):
- # Test that when the line separator is '\r\n', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
-
- def testCsvDataset_withGzipCompressionType(self):
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs,
- expected,
- linebreak='\r\n',
- compression_type='GZIP',
- record_defaults=record_defaults)
-
- def testCsvDataset_withZlibCompressionType(self):
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs,
- expected,
- linebreak='\r\n',
- compression_type='ZLIB',
- record_defaults=record_defaults)
-
- def testCsvDataset_withScalarDefaults(self):
- record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
- inputs = [[',,,', '1,1,1,', ',2,2,2']]
- self._test_dataset(
- inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
- record_defaults=record_defaults)
-
- def testCsvDataset_with2DDefaults(self):
- record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
- inputs = [[',,,', '1,1,1,', ',2,2,2']]
-
- if context.executing_eagerly():
- err_spec = errors.InvalidArgumentError, (
- 'Each record default should be at '
- 'most rank 1.')
- else:
- err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
-
- with self.assertRaisesWithPredicateMatch(*err_spec):
- self._test_dataset(
- inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
- record_defaults=record_defaults)
-
-
-class CsvDatasetBenchmark(test.Benchmark):
- """Benchmarks for the various ways of creating a dataset from CSV files.
- """
- FLOAT_VAL = '1.23456E12'
- STR_VAL = string.ascii_letters * 10
-
- def _setUp(self, str_val):
-    # Since this isn't a test.TestCase, create the test directory manually.
- gfile.MakeDirs(googletest.GetTempDir())
- self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
-
- self._num_cols = [4, 64, 256]
- self._num_per_iter = 5000
- self._filenames = []
- for n in self._num_cols:
- fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
-      with open(fn, 'w') as f:
-        # Write only 100 rows and use `repeat`. This assumes that the cost of
-        # creating an iterator is not significant.
-        row = ','.join([str_val for _ in range(n)])
-        f.write('\n'.join([row for _ in range(100)]))
- self._filenames.append(fn)
-
- def _tearDown(self):
- gfile.DeleteRecursively(self._temp_dir)
-
- def _runBenchmark(self, dataset, num_cols, prefix):
- dataset = dataset.skip(self._num_per_iter - 1)
- deltas = []
- for _ in range(10):
- next_element = dataset.make_one_shot_iterator().get_next()
- with session.Session() as sess:
- start = time.time()
- # NOTE: This depends on the underlying implementation of skip, to have
- # the net effect of calling `GetNext` num_per_iter times on the
- # input dataset. We do it this way (instead of a python for loop, or
- # batching N inputs in one iter) so that the overhead from session.run
- # or batch doesn't dominate. If we eventually optimize skip, this has
- # to change.
- sess.run(next_element)
- end = time.time()
- deltas.append(end - start)
- # Median wall time per CSV record read and decoded
- median_wall_time = np.median(deltas) / self._num_per_iter
- print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols,
- median_wall_time))
- self.report_benchmark(
- iters=self._num_per_iter,
- wall_time=median_wall_time,
- name='%s_with_cols_%d' % (prefix, num_cols))
-
- def benchmarkMapWithFloats(self):
- self._setUp(self.FLOAT_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [[0.0]] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv')
- self._tearDown()
-
- def benchmarkMapWithStrings(self):
- self._setUp(self.STR_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [['']] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
- self._tearDown()
-
- def benchmarkCsvDatasetWithFloats(self):
- self._setUp(self.FLOAT_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [[0.0]] * num_cols}
-      dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()
- self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset')
- self._tearDown()
-
- def benchmarkCsvDatasetWithStrings(self):
- self._setUp(self.STR_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [['']] * num_cols}
-      dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()
- self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
- self._tearDown()
-
-if __name__ == '__main__':
- test.main()
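
The tests above drive the `CsvDataset` reader exported as `tf.contrib.data.CsvDataset`. A minimal sketch of the usage pattern they rely on, assuming TF 1.x graph mode and a hypothetical four-column CSV file: `record_defaults` fixes the column count, dtypes, and fill values, while `header` and `select_cols` control what gets parsed.

    import tensorflow as tf

    # Hypothetical file with four float columns plus a header row.
    record_defaults = [[0.0]] * 4
    dataset = tf.contrib.data.CsvDataset(
        ['/tmp/example.csv'], record_defaults=record_defaults, header=True)
    next_row = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      print(sess.run(next_row))  # a tuple of four float32 scalars
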
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
deleted file mode 100644
index 722e87e555..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class DatasetConstructorTest(test_base.DatasetTestBase):
-
- def testRestructureDataset(self):
- components = (array_ops.placeholder(dtypes.int32),
- (array_ops.placeholder(dtypes.int32, shape=[None]),
- array_ops.placeholder(dtypes.int32, shape=[20, 30])))
- dataset = dataset_ops.Dataset.from_tensors(components)
-
- i32 = dtypes.int32
-
- test_cases = [((i32, i32, i32), None),
- (((i32, i32), i32), None),
- ((i32, i32, i32), (None, None, None)),
- ((i32, i32, i32), ([17], [17], [20, 30]))]
-
- for new_types, new_shape_lists in test_cases:
- # pylint: disable=protected-access
- new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
- # pylint: enable=protected-access
- self.assertEqual(new_types, new.output_types)
- if new_shape_lists is not None:
- for expected_shape_list, shape in zip(
- nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)):
- if expected_shape_list is None:
- self.assertIs(None, shape.ndims)
- else:
- self.assertEqual(expected_shape_list, shape.as_list())
-
- fail_cases = [((i32, dtypes.int64, i32), None),
- ((i32, i32, i32, i32), None),
- ((i32, i32, i32), ((None, None), None)),
- ((i32, i32, i32), (None, None, None, None)),
- ((i32, i32, i32), (None, [None], [21, 30]))]
-
- for new_types, new_shape_lists in fail_cases:
- with self.assertRaises(ValueError):
- # pylint: disable=protected-access
- new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
- # pylint: enable=protected-access
-
-
-if __name__ == "__main__":
- test.main()
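
The deleted test above targets the private `batching._RestructuredDataset` helper, which re-declares a dataset's output types and shapes without changing its elements. A minimal sketch, assuming TF 1.x and illustrative tensors, of the structure properties the test asserts against:

    import tensorflow as tf

    # One nested element: (scalar, (length-3 vector, 20x30 matrix)).
    components = (tf.constant(1),
                  (tf.constant([1, 2, 3]), tf.zeros([20, 30], dtype=tf.int32)))
    dataset = tf.data.Dataset.from_tensors(components)
    print(dataset.output_types)   # (tf.int32, (tf.int32, tf.int32))
    print(dataset.output_shapes)  # scalar, ([3], [20, 30]) as nested TensorShapes

The restructuring succeeds only when the requested types match and the requested shapes are compatible with these inferred values, which is what the pass and fail cases above enumerate.
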
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
deleted file mode 100644
index bc10c21472..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import random_seed
-from tensorflow.python.platform import test
-
-
-class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
-
- def testBasic(self):
- selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
- input_datasets = [
- dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
- ]
- dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset,
- input_datasets)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(100):
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def _normalize(self, vec):
- return vec / vec.sum()
-
- def _chi2(self, expected, actual):
- actual = np.asarray(actual)
- expected = np.asarray(expected)
- diff = actual - expected
- chi2 = np.sum(diff * diff / expected, axis=0)
- return chi2
-
- def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
- # Create a dataset that samples each integer in `[0, num_datasets)`
- # with probability given by `weights[i]`.
- dataset = interleave_ops.sample_from_datasets([
- dataset_ops.Dataset.from_tensors(i).repeat(None)
- for i in range(num_datasets)
- ], weights)
- dataset = dataset.take(num_samples)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- freqs = np.zeros([num_datasets])
- for _ in range(num_samples):
- freqs[sess.run(next_element)] += 1
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- return freqs
-
- def testSampleFromDatasets(self):
- random_seed.set_random_seed(1619)
- num_samples = 5000
- rand_probs = self._normalize(np.random.random_sample((15,)))
-
- # Use chi-squared test to assert that the observed distribution matches the
- # expected distribution. Based on the implementation in
- # "tensorflow/python/kernel_tests/multinomial_op_test.py".
- for probs in [[.85, .05, .1], rand_probs, [1.]]:
- probs = np.asarray(probs)
- classes = len(probs)
- freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
- self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
-
-      # Also check that passing `weights` as a dataset samples correctly.
- probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
- freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
- self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
-
- def testSelectFromDatasets(self):
- words = [b"foo", b"bar", b"baz"]
- datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
- choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
- choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
- dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in choice_array:
- self.assertEqual(words[i], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testErrors(self):
- with self.assertRaisesRegexp(ValueError,
- r"vector of length `len\(datasets\)`"):
- interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.range(10),
- dataset_ops.Dataset.range(20)],
- weights=[0.25, 0.25, 0.25, 0.25])
-
- with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
- interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.range(10),
- dataset_ops.Dataset.range(20)],
- weights=[1, 1])
-
- with self.assertRaisesRegexp(TypeError, "must have the same type"):
- interleave_ops.sample_from_datasets([
- dataset_ops.Dataset.from_tensors(0),
- dataset_ops.Dataset.from_tensors(0.0)
- ])
-
- with self.assertRaisesRegexp(TypeError, "tf.int64"):
- interleave_ops.choose_from_datasets([
- dataset_ops.Dataset.from_tensors(0),
- dataset_ops.Dataset.from_tensors(1)
- ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))
-
- with self.assertRaisesRegexp(TypeError, "scalar"):
- interleave_ops.choose_from_datasets([
- dataset_ops.Dataset.from_tensors(0),
- dataset_ops.Dataset.from_tensors(1)
- ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))
-
-
-if __name__ == "__main__":
- test.main()
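
The tests above cover weighted sampling and explicit selection across datasets. A minimal sketch of `tf.contrib.data.sample_from_datasets`, assuming TF 1.x graph mode and illustrative weights:

    import tensorflow as tf

    datasets = [tf.data.Dataset.from_tensors(i).repeat() for i in range(3)]
    sampled = tf.contrib.data.sample_from_datasets(
        datasets, weights=[0.85, 0.05, 0.1]).take(10)
    nxt = sampled.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      print([sess.run(nxt) for _ in range(10)])  # mostly 0s, occasional 1s and 2s

`choose_from_datasets` follows the same pattern, but takes an explicit dataset of scalar `tf.int64` indices instead of weights.
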
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
deleted file mode 100644
index 6d01bf585c..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Benchmarks FilterDataset input pipeline op."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.client import session
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class FilterBenchmark(test.Benchmark):
-
-  # This benchmark compares the performance of a pipeline with multiple
-  # chained filters, with and without filter fusion.
- def benchmarkFilters(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- self._benchmarkFilters(chain_length, False)
- self._benchmarkFilters(chain_length, True)
-
- def _benchmarkFilters(self, chain_length, optimize_dataset):
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
- if optimize_dataset:
- dataset = dataset.apply(optimization.optimize(["filter_fusion"]))
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(10):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- opt_mark = "opt" if optimize_dataset else "no-opt"
- print("Filter dataset {} chain length: {} Median wall time: {}".format(
- opt_mark, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_filter_dataset_chain_latency_{}_{}".format(
- opt_mark, chain_length))
-
-
-if __name__ == "__main__":
- test.main()
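
The benchmark above measures a chain of `filter` transformations with and without the `filter_fusion` rewrite. A condensed sketch of the pipeline it times, using the same internal imports as the benchmark and a chain length of two for illustration:

    from tensorflow.contrib.data.python.ops import optimization
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.ops import math_ops

    dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
    dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
    dataset = dataset.filter(lambda x: math_ops.greater_equal(x, 0))
    # With the rewrite enabled, consecutive filters are fused into a single
    # filter with a combined predicate.
    dataset = dataset.apply(optimization.optimize(["filter_fusion"]))
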
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
deleted file mode 100644
index d4d3d4adb2..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for experimental indexed dataset ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import unittest
-
-from tensorflow.contrib.data.python.ops import indexed_dataset_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.platform import test
-
-
-class IndexedDatasetOpsTest(test_base.DatasetTestBase):
-
- def testLowLevelIndexedDatasetOps(self):
- identity = ged_ops.experimental_identity_indexed_dataset(
- ops.convert_to_tensor(16, dtype=dtypes.uint64))
- handle = ged_ops.experimental_materialized_index_dataset_handle(
- container="",
- shared_name="",
- output_types=[dtypes.uint64],
- output_shapes=[[]])
- materialize = ged_ops.experimental_indexed_dataset_materialize(
- identity, handle)
- index = array_ops.placeholder(dtypes.uint64)
- get_op = ged_ops.experimental_indexed_dataset_get(
- handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
-
- with self.cached_session() as sess:
- sess.run(materialize)
- self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
-
- def testIdentityIndexedDataset(self):
- ds = indexed_dataset_ops.IdentityIndexedDataset(16)
- materialized = ds.materialize()
- with self.cached_session() as sess:
- sess.run(materialized.initializer)
- placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
- for i in range(16):
- output = sess.run(
- materialized.get(placeholder), feed_dict={placeholder: i})
- self.assertEqual([i], output)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(materialized.get(placeholder), feed_dict={placeholder: 16})
-
- @unittest.skip("Requisite functionality currently unimplemented.")
- def testIdentityIndexedDatasetIterator(self):
- ds = indexed_dataset_ops.IdentityIndexedDataset(16)
- itr = ds.make_initializable_iterator()
- n = itr.get_next()
- with self.cached_session() as sess:
- sess.run(itr.initializer)
- for i in range(16):
- output = sess.run(n)
- self.assertEqual(i, output)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(n)
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
deleted file mode 100644
index 28bd670ab5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ /dev/null
@@ -1,811 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-import math
-import threading
-import time
-
-from six.moves import zip_longest
-
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
-
- def setUp(self):
-
- self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
- self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
- self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
- self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
- self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[])
- self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[])
-
- self.error = None
- self.repeat_count = 2
-
- # Set up threading events used to sequence when items are produced that
- # are subsequently interleaved. These events allow us to deterministically
- # simulate slowdowns and force sloppiness.
- self.read_coordination_events = {}
- self.write_coordination_events = {}
- # input values [4, 5, 6] are the common case for the tests; set defaults
- for i in range(4, 7):
- self.read_coordination_events[i] = threading.Semaphore(0)
- self.write_coordination_events[i] = threading.Event()
-
- def map_py_fn(x):
- self.write_coordination_events[x].wait()
- self.write_coordination_events[x].clear()
- self.read_coordination_events[x].release()
- if self.error:
- err = self.error
- self.error = None
- raise err # pylint: disable=raising-bad-type
- return x * x
-
- def map_fn(x):
- return script_ops.py_func(map_py_fn, [x], x.dtype)
-
- def interleave_fn(x):
- dataset = dataset_ops.Dataset.from_tensors(x)
- dataset = dataset.repeat(x)
- return dataset.map(map_fn)
-
- self.dataset = (
- dataset_ops.Dataset.from_tensor_slices(self.input_values)
- .repeat(self.repeat_count).apply(
- interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
- self.block_length, self.sloppy,
- self.buffer_output_elements,
- self.prefetch_input_elements)))
- self.iterator = self.dataset.make_initializable_iterator()
- self.init_op = self.iterator.initializer
- self.next_element = self.iterator.get_next()
-
- def _interleave(self, lists, cycle_length, block_length):
- """Python implementation of interleave used for testing."""
- num_open = 0
-
- # `all_iterators` acts as a queue of iterators over each element of `lists`.
- all_iterators = [iter(l) for l in lists]
-
- # `open_iterators` are the iterators whose elements are currently being
- # interleaved.
- open_iterators = []
- for i in range(cycle_length):
- if all_iterators:
- open_iterators.append(all_iterators.pop(0))
- num_open += 1
- else:
- open_iterators.append(None)
-
- while num_open or all_iterators:
- for i in range(cycle_length):
- if open_iterators[i] is None:
- if all_iterators:
- open_iterators[i] = all_iterators.pop(0)
- num_open += 1
- else:
- continue
- for _ in range(block_length):
- try:
- yield next(open_iterators[i])
- except StopIteration:
- open_iterators[i] = None
- num_open -= 1
- break
-
- def testPythonImplementation(self):
- input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
- [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
-
- # Cycle length 1 acts like `Dataset.flat_map()`.
- expected_elements = itertools.chain(*input_lists)
- for expected, produced in zip(expected_elements,
- self._interleave(input_lists, 1, 1)):
- self.assertEqual(expected, produced)
-
- # Cycle length > 1.
- expected_elements = [
- 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
- 6, 5, 6, 5, 6, 6
- ]
- for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
- self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
- (index, expected, produced))
-
- def testPythonImplementationBlockLength(self):
- input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
- expected_elements = [
- 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
- 5, 6, 6, 5, 6, 6
- ]
- for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
- self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
- (index, expected, produced))
-
- def testPythonImplementationEmptyLists(self):
- input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
- [6, 6, 6, 6, 6, 6]]
-
- expected_elements = [
- 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
- ]
- for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
- self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
- (index, expected, produced))
-
- def _clear_coordination_events(self):
- for i in range(4, 7):
- self.read_coordination_events[i] = threading.Semaphore(0)
- self.write_coordination_events[i].clear()
-
- def _allow_all_map_threads(self):
- for i in range(4, 7):
- self.write_coordination_events[i].set()
-
- def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
-    # cycle_length=1, block_length=1 acts like `Dataset.interleave()` and
-    # `Dataset.flat_map()` and is single-threaded. No synchronization required.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 1,
- self.block_length: 1,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: prefetch_input_elements,
- })
-
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
- self.write_coordination_events[expected_element].set()
- self.assertEqual(expected_element * expected_element,
- sess.run(self.next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testSingleThreaded(self):
- self._testSingleThreaded()
-
- def testSingleThreadedSloppy(self):
- self._testSingleThreaded(sloppy=True)
-
- def testSingleThreadedPrefetch1Itr(self):
- self._testSingleThreaded(prefetch_input_elements=1)
-
- def testSingleThreadedPrefetch1ItrSloppy(self):
- self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
-
- def testSingleThreadedRagged(self):
- # Tests a sequence with wildly different elements per iterator.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [3, 7, 4],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: False,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
-
- # Add coordination values for 3 and 7
- self.read_coordination_events[3] = threading.Semaphore(0)
- self.write_coordination_events[3] = threading.Event()
- self.read_coordination_events[7] = threading.Semaphore(0)
- self.write_coordination_events[7] = threading.Event()
-
- for expected_element in self._interleave(
- [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
- self.write_coordination_events[expected_element].set()
- output = sess.run(self.next_element)
- self.assertEqual(expected_element * expected_element, output)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def _testTwoThreadsNoContention(self, sloppy=False):
- # num_threads > 1.
- # Explicit coordination should result in `Dataset.interleave()` behavior
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 1)):
- self.write_coordination_events[expected_element].set()
- if done_first_event: # First event starts the worker threads.
- self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- self.read_coordination_events[expected_element].acquire()
- done_first_event = True
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testTwoThreadsNoContention(self):
- self._testTwoThreadsNoContention()
-
- def testTwoThreadsNoContentionSloppy(self):
- self._testTwoThreadsNoContention(sloppy=True)
-
- def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
- """Tests where all the workers race in producing elements.
-
- Note: this is in contrast with the previous test which carefully sequences
- the execution of the map functions.
-
- Args:
- sloppy: Whether to be sloppy or not.
- """
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 1)):
- if done_first_event: # First event starts the worker threads.
- self._allow_all_map_threads()
- self.read_coordination_events[expected_element].acquire()
- else:
- self.write_coordination_events[expected_element].set()
- time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- done_first_event = True
- self.assertTrue(
- self.read_coordination_events[expected_element].acquire(False))
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testTwoThreadsNoContentionWithRaces(self):
- self._testTwoThreadsNoContentionWithRaces()
-
- def testTwoThreadsNoContentionWithRacesSloppy(self):
- self._testTwoThreadsNoContentionWithRaces(sloppy=True)
-
- def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
- # num_threads > 1.
- # Explicit coordination should result in `Dataset.interleave()` behavior
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 2,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 2)):
- self.write_coordination_events[expected_element].set()
- if done_first_event: # First event starts the worker threads.
- self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- done_first_event = True
- self.read_coordination_events[expected_element].acquire()
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testTwoThreadsNoContentionBlockLength(self):
- self._testTwoThreadsNoContentionBlockLength()
-
- def testTwoThreadsNoContentionBlockLengthSloppy(self):
- self._testTwoThreadsNoContentionBlockLength(sloppy=True)
-
- def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
- """Tests where all the workers race in producing elements.
-
- Note: this is in contrast with the previous test which carefully sequences
- the execution of the map functions.
-
- Args:
- sloppy: Whether to be sloppy or not.
- """
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 2,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 2)):
- if done_first_event: # First event starts the worker threads.
- self._allow_all_map_threads()
- self.read_coordination_events[expected_element].acquire()
- else:
- self.write_coordination_events[expected_element].set()
- time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- done_first_event = True
- self.assertTrue(
- self.read_coordination_events[expected_element].acquire(False))
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testTwoThreadsNoContentionWithRacesAndBlocking(self):
- self._testTwoThreadsNoContentionWithRacesAndBlocking()
-
- def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
- self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
-
- def _testEmptyInput(self, sloppy=False):
- with self.cached_session() as sess:
- # Empty input.
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [],
- self.cycle_length: 2,
- self.block_length: 3,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testEmptyInput(self):
- self._testEmptyInput()
-
- def testEmptyInputSloppy(self):
- self._testEmptyInput(sloppy=True)
-
- def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
- # Non-empty input leading to empty output.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [0, 0, 0],
- self.cycle_length: 2,
- self.block_length: 3,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testNonEmptyInputIntoEmptyOutputs(self):
- self._testNonEmptyInputIntoEmptyOutputs()
-
- def testNonEmptyInputIntoEmptyOutputsSloppy(self):
- self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
-
- def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
-    race_indices = {2, 8, 14}  # Sequence points where sloppy mode has races
- # Mixture of non-empty and empty interleaved datasets.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 0, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: prefetch_input_elements,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
- self.write_coordination_events[expected_element].set()
- # First event starts the worker threads. Additionally, when running the
- # sloppy case with prefetch_input_elements=0, we get stuck if we wait
- # for the read coordination event for certain event orderings in the
- # presence of finishing iterators.
- if done_first_event and not (sloppy and (i in race_indices)):
- self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
- if not done_first_event or (sloppy and (i in race_indices)):
- done_first_event = True
- self.read_coordination_events[expected_element].acquire()
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
-
- def testPartiallyEmptyOutputs(self):
- self._testPartiallyEmptyOutputs()
-
- def testPartiallyEmptyOutputsSloppy(self):
- self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
-
- def testDelayedOutputSloppy(self):
- # Explicitly control the sequence of events to ensure we correctly avoid
- # head-of-line blocking.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: True,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
-
- mis_ordering = [
- 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6,
- 6, 5, 5, 5, 5, 6, 6
- ]
- for element in mis_ordering:
- self.write_coordination_events[element].set()
- self.assertEqual(element * element, sess.run(self.next_element))
- self.assertTrue(self.read_coordination_events[element].acquire(False))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testBlockLengthWithContentionSloppy(self):
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: True,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- # Test against a generating sequence that differs from the uncontended
- # case, in order to prove sloppy correctness.
- for i, expected_element in enumerate(
- self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
- cycle_length=2,
- block_length=3)):
- self.write_coordination_events[expected_element].set()
- if done_first_event: # First event starts the worker threads.
- self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- self.read_coordination_events[expected_element].acquire()
- done_first_event = True
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def _testEarlyExit(self, sloppy=False):
- # Exiting without consuming all input should not block
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 3,
- self.block_length: 2,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- for i in range(4, 7):
- self.write_coordination_events[i].set()
- elem = sess.run(self.next_element) # Start all workers
- # Allow the one successful worker to progress beyond the py_func again.
- elem = int(math.sqrt(elem))
- self.write_coordination_events[elem].set()
- self.read_coordination_events[elem].acquire()
- # Allow the prefetch to succeed
- for i in range(4, 7):
- self.read_coordination_events[i].acquire()
- self.write_coordination_events[i].set()
-
- def testEarlyExit(self):
- self._testEarlyExit()
-
- def testEarlyExitSloppy(self):
- self._testEarlyExit(sloppy=True)
-
- def _testTooManyReaders(self, sloppy=False):
-
- def interleave_fn(x):
- dataset = dataset_ops.Dataset.from_tensors(x)
- dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
- return dataset
-
- dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
- dataset = dataset.repeat(self.repeat_count)
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
- iterator = dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- output_values = []
- for _ in range(30):
- output_values.append(sess.run(iterator.get_next()))
-
- expected_values = self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
- self.assertItemsEqual(output_values, expected_values)
-
- def testTooManyReaders(self):
- self._testTooManyReaders()
-
- def testTooManyReadersSloppy(self):
- self._testTooManyReaders(sloppy=True)
-
- def testSparse(self):
- def _map_fn(i):
- return sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _interleave_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- dataset = dataset_ops.Dataset.range(10).map(_map_fn)
- iterator = dataset.apply(
- interleave_ops.parallel_interleave(
- _interleave_fn, cycle_length=1)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(10):
- for j in range(2):
- expected = [i, 0] if j % 2 == 0 else [0, -i]
- self.assertAllEqual(expected, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testErrorsInOutputFn(self):
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: False,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
-
- except_on_element_indices = set([3])
-
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 1)):
- if i in except_on_element_indices:
- self.error = ValueError()
- self.write_coordination_events[expected_element].set()
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
- else:
- self.write_coordination_events[expected_element].set()
- actual_element = sess.run(self.next_element)
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testErrorsInInputFn(self):
-
- def map_py_fn(x):
- if x == 5:
- raise ValueError()
- return x
-
- def map_fn(x):
- return script_ops.py_func(map_py_fn, [x], x.dtype)
-
- def interleave_fn(x):
- dataset = dataset_ops.Dataset.from_tensors(x)
- dataset = dataset.repeat(x)
- return dataset
-
- self.dataset = (
- dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn)
- .repeat(self.repeat_count).apply(
- interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
- self.block_length, self.sloppy,
- self.buffer_output_elements,
- self.prefetch_input_elements)))
-
- self.iterator = self.dataset.make_initializable_iterator()
- self.init_op = self.iterator.initializer
- self.next_element = self.iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: False,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
- if expected_element == 5:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
- else:
- actual_element = sess.run(self.next_element)
- self.assertEqual(expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testErrorsInInterleaveFn(self):
-
- def map_py_fn(x):
- if x == 5:
- raise ValueError()
- return x
-
- def interleave_fn(x):
- dataset = dataset_ops.Dataset.from_tensors(x)
- y = script_ops.py_func(map_py_fn, [x], x.dtype)
- dataset = dataset.repeat(y)
- return dataset
-
- self.dataset = (
- dataset_ops.Dataset.from_tensor_slices(self.input_values)
- .repeat(self.repeat_count).apply(
- interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
- self.block_length, self.sloppy,
- self.buffer_output_elements,
- self.prefetch_input_elements)))
-
- self.iterator = self.dataset.make_initializable_iterator()
- self.init_op = self.iterator.initializer
- self.next_element = self.iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: False,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
- if expected_element == 5:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
- else:
- actual_element = sess.run(self.next_element)
- self.assertEqual(expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testShutdownRace(self):
- dataset = dataset_ops.Dataset.range(20)
- map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- map_fn,
- cycle_length=3,
- sloppy=False,
- buffer_output_elements=1,
- prefetch_input_elements=0))
- dataset = dataset.batch(32)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- results = []
- with self.cached_session() as sess:
- for _ in range(2):
- elements = []
- sess.run(iterator.initializer)
- try:
- while True:
- elements.extend(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
- results.append(elements)
-
- self.assertAllEqual(results[0], results[1])
-
-
-if __name__ == "__main__":
- test.main()
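
The tests above use threading events to pin down the ordering guarantees of `parallel_interleave` in both deterministic and sloppy modes. A minimal sketch of the transformation under test, assuming TF 1.x and hypothetical file names:

    import tensorflow as tf

    # Hypothetical text files; each becomes a nested TextLineDataset.
    filenames = tf.data.Dataset.from_tensor_slices(['/tmp/a.txt', '/tmp/b.txt'])
    lines = filenames.apply(
        tf.contrib.data.parallel_interleave(
            lambda f: tf.data.TextLineDataset(f),
            cycle_length=2, block_length=1, sloppy=True))
    # sloppy=True trades deterministic ordering for throughput: an element is
    # emitted as soon as any of the cycle_length readers has one ready.
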
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
deleted file mode 100644
index 58a1d7c93b..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for experimental iterator_ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import iterator_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training import training_util
-
-
-class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
-
- @staticmethod
- def _model_fn(features, labels, mode, config):
- del labels
- del mode
- del config
- global_step = training_util.get_or_create_global_step()
- update_global_step_op = global_step.assign_add(1)
- latest_feature = variables.VariableV1(
- 0, name='latest_feature', dtype=dtypes.int64)
- store_latest_feature_op = latest_feature.assign(features)
- ops.add_to_collection('my_vars', global_step)
- ops.add_to_collection('my_vars', latest_feature)
- return model_fn.EstimatorSpec(
- mode='train',
- train_op=control_flow_ops.group(
- [update_global_step_op, store_latest_feature_op]),
- loss=constant_op.constant(2.0))
-
- def _read_vars(self, model_dir):
- """Returns (global_step, latest_feature)."""
- with ops.Graph().as_default() as g:
- ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
- meta_filename = ckpt_path + '.meta'
- saver_lib.import_meta_graph(meta_filename)
- saver = saver_lib.Saver()
- with self.session(graph=g) as sess:
- saver.restore(sess, ckpt_path)
- return sess.run(ops.get_collection('my_vars'))
-
- def _build_iterator_saver_hook(self, est):
- return iterator_ops.CheckpointInputPipelineHook(est)
-
- def testReturnDatasetFromInputFn(self):
-
- def _input_fn():
- return dataset_ops.Dataset.range(10)
-
- est = estimator.Estimator(model_fn=self._model_fn)
-
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
-
- def testBuildIteratorInInputFn(self):
-
- def _input_fn():
- ds = dataset_ops.Dataset.range(10)
- iterator = ds.make_one_shot_iterator()
- return iterator.get_next()
-
- est = estimator.Estimator(model_fn=self._model_fn)
-
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
-
- def testDoNotRestore(self):
-
- def _input_fn():
- return dataset_ops.Dataset.range(10)
-
- est = estimator.Estimator(model_fn=self._model_fn)
-
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
- # Hook not provided, input pipeline was not restored.
- est.train(_input_fn, steps=2)
- self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
-
- def testRaiseErrorIfNoIterator(self):
-
- def _input_fn():
- return constant_op.constant(1, dtype=dtypes.int64)
-
- est = estimator.Estimator(model_fn=self._model_fn)
-
- with self.assertRaises(ValueError):
- est.train(
- _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
-
-
-if __name__ == '__main__':
- test.main()
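After this change the hook exercised by the deleted test above is exposed as tf.data.experimental.CheckpointInputPipelineHook, with the tf.contrib.data endpoint kept only as a deprecated alias. A hedged usage sketch in TF 1.x terms (my_model_fn and the model_dir path are placeholders, not taken from the original test):

import tensorflow as tf

def input_fn():
  return tf.data.Dataset.range(10)

def my_model_fn(features, labels, mode):
  # Placeholder model_fn that only advances the global step.
  del features, labels
  global_step = tf.train.get_or_create_global_step()
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=tf.constant(0.0),
      train_op=global_step.assign_add(1))

est = tf.estimator.Estimator(model_fn=my_model_fn, model_dir="/tmp/ckpt_hook_demo")
hook = tf.data.experimental.CheckpointInputPipelineHook(est)
est.train(input_fn, steps=2, hooks=[hook])  # iterator state is saved with checkpoints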
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
deleted file mode 100644
index 385c4ef6ea..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ /dev/null
@@ -1,359 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import hashlib
-import itertools
-import os
-import time
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-_NUMPY_RANDOM_SEED = 42
-
-
-class MapDatasetTest(test_base.DatasetTestBase):
-
- def testMapIgnoreError(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.check_numerics(x, "message")).apply(
- error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for x in [1., 2., 3., 5.]:
- self.assertEqual(x, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testParallelMapIgnoreError(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.check_numerics(x, "message"),
- num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for x in [1., 2., 3., 5.]:
- self.assertEqual(x, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testReadFileIgnoreError(self):
-
- def write_string_to_file(value, filename):
- with open(filename, "w") as f:
- f.write(value)
-
- filenames = [
- os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
- ]
- for filename in filenames:
- write_string_to_file(filename, filename)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(filenames).map(
- io_ops.read_file,
- num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # All of the files are present.
- sess.run(init_op)
- for filename in filenames:
- self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Delete one of the files.
- os.remove(filenames[0])
-
- # Attempting to read filenames[0] will fail, but ignore_errors()
- # will catch the error.
- sess.run(init_op)
- for filename in filenames[1:]:
- self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testCaptureResourceInMapFn(self):
-
- def _build_ds(iterator):
-
- def _map_fn(x):
- get_next = iterator.get_next()
- return x * get_next
-
- return dataset_ops.Dataset.range(10).map(_map_fn)
-
- def _build_graph():
- captured_iterator = dataset_ops.Dataset.range(
- 10).make_initializable_iterator()
- ds = _build_ds(captured_iterator)
- iterator = ds.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- return captured_iterator.initializer, init_op, get_next
-
- with ops.Graph().as_default() as g:
- captured_init_op, init_op, get_next = _build_graph()
- with self.session(graph=g) as sess:
- sess.run(captured_init_op)
- sess.run(init_op)
- for i in range(10):
- self.assertEquals(i * i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-class MapDatasetBenchmark(test.Benchmark):
-
- # The purpose of this benchmark is to compare the performance of chaining vs
- # fusing of the map and batch transformations across various configurations.
- #
- # NOTE: It is recommended to build the benchmark with
- # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
- # and execute it on a machine with at least 32 CPU cores.
- def benchmarkMapAndBatch(self):
-
- # Sequential pipeline configurations.
- seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
- seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
-
- # Parallel pipeline configuration.
- par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
- par_batch_size_series = itertools.product([32], [32], [1],
- [128, 256, 512, 1024])
- par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
- par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
-
- def name(method, label, num_calls, inter_op, element_size, batch_size):
- return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
- method,
- hashlib.sha1(label).hexdigest(),
- num_calls,
- inter_op,
- element_size,
- batch_size,
- ))
-
- def benchmark(label, series):
-
- print("%s:" % label)
- for num_calls, inter_op, element_size, batch_size in series:
-
- num_iters = 1024 // (
- (element_size * batch_size) // min(num_calls, inter_op))
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
- element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
-
- chained_dataset = dataset.map(
- math_ops.matmul,
- num_parallel_calls=num_calls).batch(batch_size=batch_size)
- chained_iterator = chained_dataset.make_one_shot_iterator()
- chained_get_next = chained_iterator.get_next()
-
- chained_deltas = []
- with session.Session(
- config=config_pb2.ConfigProto(
- inter_op_parallelism_threads=inter_op,
- use_per_session_threads=True)) as sess:
- for _ in range(5):
- sess.run(chained_get_next.op)
- for _ in range(num_iters):
- start = time.time()
- sess.run(chained_get_next.op)
- end = time.time()
- chained_deltas.append(end - start)
-
- fused_dataset = dataset.apply(
- batching.map_and_batch(
- math_ops.matmul,
- num_parallel_calls=num_calls,
- batch_size=batch_size))
- fused_iterator = fused_dataset.make_one_shot_iterator()
- fused_get_next = fused_iterator.get_next()
-
- fused_deltas = []
- with session.Session(
- config=config_pb2.ConfigProto(
- inter_op_parallelism_threads=inter_op,
- use_per_session_threads=True)) as sess:
-
- for _ in range(5):
- sess.run(fused_get_next.op)
- for _ in range(num_iters):
- start = time.time()
- sess.run(fused_get_next.op)
- end = time.time()
- fused_deltas.append(end - start)
-
- print(
- "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
- "element size: %d, num iters: %d\nchained wall time: %f (median), "
- "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
- "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
- "chained/fused: %.2fx (median), %.2fx (mean)" %
- (batch_size, num_calls, inter_op, element_size, num_iters,
- np.median(chained_deltas), np.mean(chained_deltas),
- np.std(chained_deltas), np.min(chained_deltas),
- np.max(chained_deltas), np.median(fused_deltas),
- np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
- np.max(fused_deltas),
- np.median(chained_deltas) / np.median(fused_deltas),
- np.mean(chained_deltas) / np.mean(fused_deltas)))
-
- self.report_benchmark(
- iters=num_iters,
- wall_time=np.median(chained_deltas),
- name=name("chained", label, num_calls, inter_op, element_size,
- batch_size))
-
- self.report_benchmark(
- iters=num_iters,
- wall_time=np.median(fused_deltas),
- name=name("fused", label, num_calls, inter_op, element_size,
- batch_size))
-
- print("")
-
- np.random.seed(_NUMPY_RANDOM_SEED)
- benchmark("Sequential element size evaluation", seq_elem_size_series)
- benchmark("Sequential batch size evaluation", seq_batch_size_series)
- benchmark("Parallel element size evaluation", par_elem_size_series)
- benchmark("Parallel batch size evaluation", par_batch_size_series)
- benchmark("Transformation parallelism evaluation", par_num_calls_series)
- benchmark("Threadpool size evaluation", par_inter_op_series)
-
-  # This benchmark compares the performance of a pipeline with multiple
-  # chained maps, with and without map fusion.
- def benchmarkChainOfMaps(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- self._benchmarkChainOfMaps(chain_length, False)
- self._benchmarkChainOfMaps(chain_length, True)
-
- def _benchmarkChainOfMaps(self, chain_length, optimize_dataset):
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.map(lambda x: x)
- if optimize_dataset:
- dataset = dataset.apply(optimization.optimize(["map_fusion"]))
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- opt_mark = "opt" if optimize_dataset else "no-opt"
- print("Map dataset {} chain length: {} Median wall time: {}".format(
- opt_mark, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_map_dataset_chain_latency_{}_{}".format(
- opt_mark, chain_length))
-
-
-class MapAndFilterBenchmark(test.Benchmark):
-
-  # This benchmark compares the performance of a pipeline with multiple chained
-  # map + filter transformations, with and without map and filter fusion.
- def benchmarkMapAndFilter(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- self._benchmarkMapAndFilter(chain_length, False)
- self._benchmarkMapAndFilter(chain_length, True)
-
- def _benchmarkMapAndFilter(self, chain_length, optimize_dataset):
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.map(lambda x: x + 5).filter(
- lambda x: math_ops.greater_equal(x - 5, 0))
- if optimize_dataset:
- dataset = dataset.apply(
- optimization.optimize(["map_and_filter_fusion"]))
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(10):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- opt_mark = "opt" if optimize_dataset else "no-opt"
- print("Map and filter dataset {} chain length: {} Median wall time: {}".
- format(opt_mark, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format(
- opt_mark, chain_length))
-
-
-if __name__ == "__main__":
- test.main()
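The transformations exercised above keep their names under the new namespace: tf.contrib.data.ignore_errors becomes tf.data.experimental.ignore_errors and tf.contrib.data.map_and_batch becomes tf.data.experimental.map_and_batch. A brief sketch of both patterns (the map functions and sizes are illustrative, not the benchmark's own):

import tensorflow as tf

values = tf.constant([1., 2., 3., float("nan"), 5.])

# Drop elements whose map function raises an error (here, the NaN element).
clean = (tf.data.Dataset.from_tensor_slices(values)
         .map(lambda x: tf.check_numerics(x, "message"))
         .apply(tf.data.experimental.ignore_errors()))

# Chained map->batch versus the fused transformation compared in the benchmark.
base = tf.data.Dataset.from_tensor_slices(values).repeat()
chained = base.map(lambda x: x * 2., num_parallel_calls=4).batch(32)
fused = base.apply(tf.data.experimental.map_and_batch(
    lambda x: x * 2., batch_size=32, num_parallel_calls=4))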
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
deleted file mode 100644
index 751e6d5b30..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ /dev/null
@@ -1,281 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for MapDefunOp."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-from tensorflow.contrib.data.python.ops import map_defun
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import data_flow_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class MapDefunTest(test_base.DatasetTestBase):
-
- def testMapDefunSimple(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- return x * 2 + 3
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
- expected = elems * 2 + 3
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
-
- def testMapDefunMismatchedTypes(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return math_ops.cast(x, dtypes.float64)
-
- nums = [1, 2, 3, 4, 5, 6]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(r)
-
- def testMapDefunReduceDim(self):
- # Tests where the output has a different rank from the input
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return array_ops.gather(x, 0)
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
- expected = constant_op.constant([1, 3, 5])
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
-
- def testMapDefunMultipleOutputs(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return (x, math_ops.cast(x * 2 + 3, dtypes.float64))
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
- (2,)])
- expected = [elems, elems * 2 + 3]
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
-
- def testMapDefunShapeInference(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return x
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
- self.assertEqual(result.get_shape(), (3, 2))
-
- def testMapDefunPartialShapeInference(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return x
-
- elems = array_ops.placeholder(dtypes.int64, (None, 2))
- result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
- self.assertEqual(result[0].get_shape().as_list(), [None, 2])
-
- def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
-
- @function.Defun(dtypes.int32, dtypes.int32)
- def fn(x, y):
- return x, y
-
- elems1 = array_ops.placeholder(dtypes.int32)
- elems2 = array_ops.placeholder(dtypes.int32)
- result = map_defun.map_defun(fn, [elems1, elems2],
- [dtypes.int32, dtypes.int32], [(), ()])
- with self.cached_session() as sess:
- with self.assertRaisesWithPredicateMatch(
- errors.InvalidArgumentError,
- "All inputs must have the same dimension 0."):
- sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
-
- def testMapDefunRaisesDefunError(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
- return array_ops.identity(x)
-
- elems = constant_op.constant([0, 0, 0, 37, 0])
- result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(result)
-
- def testMapDefunCancelledCorrectly(self):
-
- @function.Defun(dtypes.int64)
- def defun(x):
-      # x has leading dimension 5, so this will raise an error
- return array_ops.gather(x, 10)
-
- c = array_ops.tile(
- array_ops.expand_dims(
- constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
- [100, 1])
- map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- r"indices = 10 is not in \[0, 5\)"):
- self.evaluate(map_defun_op)
-
- def testMapDefunWithUnspecifiedOutputShape(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- res = x * 2 + 3
- return (res, res + 1, res + 2)
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems],
- [dtypes.int32, dtypes.int32, dtypes.int32],
- [None, (None,), (2,)])
- expected = elems * 2 + 3
- self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
- self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
- self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
-
- def testMapDefunWithDifferentOutputShapeEachRun(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- return x * 2 + 3
-
- elems = array_ops.placeholder(dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
- with session.Session() as sess:
- self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
- self.assertAllEqual(
- sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
-
- def testMapDefunWithWrongOutputShape(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- return x * 2 + 3
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(r)
-
- def testMapDefunWithInvalidInput(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- return x * 2
-
- c = constant_op.constant(2)
- with self.assertRaises(ValueError):
- # Fails at graph construction time for inputs with known shapes.
- r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
- p = array_ops.placeholder(dtypes.int32)
- r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
- with session.Session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(r, feed_dict={p: 0})
-
- def _assert_op_cancelled(self, sess, map_defun_op):
- with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
- sess.run(map_defun_op)
-
- def testMapDefunWithParentCancellation(self):
- # Checks that a cancellation of the parent graph is threaded through to
- # MapDefunOp correctly.
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- del x
- queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
- # Blocking
- return queue.dequeue_many(5)
-
- c = constant_op.constant([1, 2, 3, 4, 5])
- map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
-
- with self.cached_session() as sess:
- thread = self.checkedThread(
- self._assert_op_cancelled, args=(sess, map_defun_op))
- thread.start()
- time.sleep(0.1)
- sess.close()
- thread.join()
-
-
-class MapDefunBenchmark(test.Benchmark):
-
- def _run(self, op, name=None, num_iters=3000):
- with session.Session() as sess:
- # Warm up the session
- for _ in range(5):
- sess.run(op)
- start = time.time()
- for _ in range(num_iters):
- sess.run(op)
- end = time.time()
- mean_us = (end - start) * 1e6 / num_iters
- self.report_benchmark(
- name=name,
- iters=num_iters,
- wall_time=mean_us,
- extras={"examples_per_sec": num_iters / (end - start)})
-
- def benchmarkDefunVsMapFn(self):
- """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
-
- @function.Defun(dtypes.int32)
- def defun(x):
- return array_ops.identity(x)
-
- def map_fn(x):
- return array_ops.identity(x)
-
- base = math_ops.range(100)
- for input_size in [10, 100, 1000, 10000]:
- num_iters = 100000 // input_size
- map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
- map_fn_op = functional_ops.map_fn(map_fn, base)
-
- self._run(
- map_defun_op,
- "benchmarkMapDefun_size_%d" % input_size,
- num_iters=num_iters)
- self._run(
- map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
-
-if __name__ == "__main__":
- test.main()
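map_defun has no public endpoint; the closest public analogue of what the tests above compute is tf.map_fn, which likewise applies a function independently along dimension 0 of its input. A small sketch of that analogue (not the MapDefunOp itself):

import tensorflow as tf

elems = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int32)
result = tf.map_fn(lambda x: x * 2 + 3, elems)  # applied row by row

with tf.Session() as sess:
  print(sess.run(result))  # [[ 5  7] [ 9 11] [13 15]]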
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
deleted file mode 100644
index d7b5edcd9a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ /dev/null
@@ -1,164 +0,0 @@
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
- name = "assert_next_dataset_op_test",
- size = "medium",
- srcs = ["assert_next_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "hoist_random_uniform_test",
- size = "small",
- srcs = ["hoist_random_uniform_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "latency_all_edges_test",
- size = "small",
- srcs = ["latency_all_edges_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "map_vectorization_test",
- size = "small",
- srcs = ["map_vectorization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:session",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "map_and_filter_fusion_test",
- size = "medium",
- srcs = ["map_and_filter_fusion_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "map_parallelization_test",
- size = "small",
- srcs = ["map_parallelization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "model_dataset_op_test",
- size = "medium",
- srcs = ["model_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "noop_elimination_test",
- size = "small",
- srcs = ["noop_elimination_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "optimize_dataset_op_test",
- size = "small",
- srcs = ["optimize_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
deleted file mode 100644
index fe1b5280ba..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class AssertNextDatasetTest(test_base.DatasetTestBase):
-
- def testAssertNext(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertEqual(0, sess.run(get_next))
-
- def testAssertNextInvalid(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted Whoops transformation at offset 0 but encountered "
- "Map transformation instead."):
- sess.run(get_next)
-
- def testAssertNextShort(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted next 2 transformations but encountered only 1."):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
deleted file mode 100644
index b43efb5c7c..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the HoistRandomUniform optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
-
-
-class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @staticmethod
- def map_functions():
- plus_one = lambda x: x + 1
-
- def random(_):
- return random_ops.random_uniform([],
- minval=1,
- maxval=10,
- dtype=dtypes.float32,
- seed=42)
-
- def random_with_assert(x):
- y = random(x)
- assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
- with ops.control_dependencies([assert_op]):
- return y
-
- twice_random = lambda x: (random(x) + random(x)) / 2.
-
- tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
- ("RandomWithAssert", random_with_assert, True),
- ("TwiceRandom", twice_random, False)]
- return tuple(tests)
-
- @parameterized.named_parameters(*map_functions.__func__())
- def testHoisting(self, function, will_optimize):
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(
- ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
-
- dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"]))
- self._testDataset(dataset)
-
- def testAdditionalInputs(self):
- a = constant_op.constant(1, dtype=dtypes.float32)
- b = constant_op.constant(0, dtype=dtypes.float32)
- some_tensor = math_ops.mul(a, b)
-
- def random_with_capture(_):
- return some_tensor + random_ops.random_uniform(
- [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
-
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(
- ["Zip[0]", "Map"])).map(random_with_capture).apply(
- optimization.optimize(["hoist_random_uniform"]))
- self._testDataset(dataset)
-
- def _testDataset(self, dataset):
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- previous_result = 0
- with self.cached_session() as sess:
- for _ in range(5):
- result = sess.run(get_next)
- self.assertLessEqual(1, result)
- self.assertLessEqual(result, 10)
-        # This is a weak check that the values are actually random: consecutive
-        # results should not be identical.
- self.assertNotEqual(previous_result, result)
- previous_result = result
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
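The structure asserted above ("Zip[0]" followed by "Map") reflects the intent of hoist_random_uniform: pull the stateful random op out of the per-element map function so the remaining map is stateless. A hand-written illustration of that idea only; the actual Grappler rewrite uses a random dataset plus stateless random ops, not this literal form:

import tensorflow as tf

def random_fn(_):
  return tf.random_uniform([], minval=1, maxval=10, dtype=tf.float32, seed=42)

inputs = tf.data.Dataset.range(5)

# Original form: the map function itself draws the random number.
stateful = inputs.map(random_fn)

# Hand-hoisted form: the randomness comes from a separate dataset that is
# zipped with the input, and the final map is stateless.
randomness = tf.data.Dataset.range(5).map(random_fn)
hoisted = tf.data.Dataset.zip((randomness, inputs)).map(lambda r, _: r)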
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
deleted file mode 100644
index e4f18222fd..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the LatencyAllEdges optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
-
- def testLatencyStatsOptimization(self):
-
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.from_tensors(1).apply(
- optimization.assert_next(
- ["LatencyStats", "Map", "LatencyStats", "Prefetch",
- "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
- stats_ops.set_stats_aggregator(stats_aggregator)).apply(
- optimization.optimize(["latency_all_edges"]))
- iterator = dataset.make_initializable_iterator()
- get_next = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertEqual(1 * 1, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str,
- "record_latency_TensorDataset/_1", 1)
- self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
- 1)
- self._assertSummaryHasCount(summary_str,
- "record_latency_PrefetchDataset/_6", 1)
-
-
-if __name__ == "__main__":
- test.main()
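Per the commit message, the stats endpoints land directly in tf.data.experimental with no tf.contrib.data alias. A hedged sketch of the equivalent pipeline using those endpoints (names assumed to mirror the contrib ones used above; the latency_all_edges rewrite itself is applied by Grappler, not by user code):

import tensorflow as tf

aggregator = tf.data.experimental.StatsAggregator()
dataset = (tf.data.Dataset.from_tensors(1)
           .map(lambda x: x * x)
           .prefetch(1)
           .apply(tf.data.experimental.set_stats_aggregator(aggregator)))
summary_op = aggregator.get_summary()  # summary protos carrying the latency stats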
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
deleted file mode 100644
index e9e3fc81e5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ /dev/null
@@ -1,225 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapAndFilterFusion optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @staticmethod
- def map_functions():
- identity = lambda x: x
- increment = lambda x: x + 1
-
- def increment_and_square(x):
- y = x + 1
- return y * y
-
- functions = [identity, increment, increment_and_square]
- tests = []
- for i, fun1 in enumerate(functions):
- for j, fun2 in enumerate(functions):
- tests.append((
- "Test{}{}".format(i, j),
- [fun1, fun2],
- ))
- for k, fun3 in enumerate(functions):
- tests.append((
- "Test{}{}{}".format(i, j, k),
- [fun1, fun2, fun3],
- ))
-
- swap = lambda x, n: (n, x)
- tests.append((
- "Swap1",
- [lambda x: (x, 42), swap],
- ))
- tests.append((
- "Swap2",
- [lambda x: (x, 42), swap, swap],
- ))
- return tuple(tests)
-
- @parameterized.named_parameters(*map_functions.__func__())
- def testMapFusion(self, functions):
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(["Map", "Prefetch"]))
- for function in functions:
- dataset = dataset.map(function)
-
- dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- for x in range(5):
- result = sess.run(get_next)
- r = x
- for function in functions:
- if isinstance(r, tuple):
- r = function(*r) # Pass tuple as multiple arguments.
- else:
- r = function(r)
- self.assertAllEqual(r, result)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @staticmethod
- def map_and_filter_functions():
- identity = lambda x: x
- increment = lambda x: x + 1
- minus_five = lambda x: x - 5
-
- def increment_and_square(x):
- y = x + 1
- return y * y
-
- take_all = lambda x: constant_op.constant(True)
- is_zero = lambda x: math_ops.equal(x, 0)
-    is_even = lambda x: math_ops.equal(x % 2, 0)
- greater = lambda x: math_ops.greater(x + 5, 0)
-
- functions = [identity, increment, minus_five, increment_and_square]
-    filters = [take_all, is_zero, is_even, greater]
- tests = []
-
- for x, fun in enumerate(functions):
- for y, predicate in enumerate(filters):
- tests.append(("Mixed{}{}".format(x, y), fun, predicate))
-
- # Multi output
- tests.append(("Multi1", lambda x: (x, x),
- lambda x, y: constant_op.constant(True)))
- tests.append(
- ("Multi2", lambda x: (x, 2),
- lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
- return tuple(tests)
-
- @parameterized.named_parameters(*map_and_filter_functions.__func__())
- def testMapFilterFusion(self, function, predicate):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map",
- "FilterByLastComponent"])).map(function).filter(predicate).apply(
- optimization.optimize(["map_and_filter_fusion"]))
- self._testMapAndFilter(dataset, function, predicate)
-
- def _testMapAndFilter(self, dataset, function, predicate):
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- for x in range(10):
- r = function(x)
- if isinstance(r, tuple):
- b = predicate(*r) # Pass tuple as multiple arguments.
- else:
- b = predicate(r)
- if sess.run(b):
- result = sess.run(get_next)
- self.assertAllEqual(r, result)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testAdditionalInputs(self):
- a = constant_op.constant(3, dtype=dtypes.int64)
- b = constant_op.constant(4, dtype=dtypes.int64)
- some_tensor = math_ops.mul(a, b)
- function = lambda x: x * x
-
- def predicate(y):
- return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)
-
-    # Fusion does not currently support functions with additional (captured) inputs.
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Filter"])).map(function).filter(predicate).apply(
- optimization.optimize(["map_and_filter_fusion"]))
-
- self._testMapAndFilter(dataset, function, predicate)
-
- @staticmethod
- def filter_functions():
- take_all = lambda x: constant_op.constant(True)
- is_zero = lambda x: math_ops.equal(x, 0)
- greater = lambda x: math_ops.greater(x + 5, 0)
-
- tests = []
- filters = [take_all, is_zero, greater]
- identity = lambda x: x
- for x, predicate_1 in enumerate(filters):
- for y, predicate_2 in enumerate(filters):
- tests.append(("Mixed{}{}".format(x, y), identity,
- [predicate_1, predicate_2]))
- for z, predicate_3 in enumerate(filters):
- tests.append(("Mixed{}{}{}".format(x, y, z), identity,
- [predicate_1, predicate_2, predicate_3]))
-
- take_all_multiple = lambda x, y: constant_op.constant(True)
- # Multi output
- tests.append(("Multi1", lambda x: (x, x),
- [take_all_multiple, take_all_multiple]))
- tests.append(("Multi2", lambda x: (x, 2), [
- take_all_multiple,
- lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
- ]))
- return tuple(tests)
-
- @parameterized.named_parameters(*filter_functions.__func__())
- def testFilterFusion(self, map_function, predicates):
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(["Map", "Filter",
- "Prefetch"])).map(map_function)
- for predicate in predicates:
- dataset = dataset.filter(predicate)
-
- dataset = dataset.prefetch(0).apply(
- optimization.optimize(["filter_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- for x in range(5):
- r = map_function(x)
- filtered = False
- for predicate in predicates:
- if isinstance(r, tuple):
- b = predicate(*r) # Pass tuple as multiple arguments.
- else:
- b = predicate(r)
- if not sess.run(b):
- filtered = True
- break
-
- if not filtered:
- result = sess.run(get_next)
- self.assertAllEqual(r, result)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
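A rough hand-written equivalent of what map_and_filter_fusion produces for ds.map(f).filter(p): a single map that also emits the predicate value, followed by a filter on that last component, which is why the test above asserts ["Map", "FilterByLastComponent"]. Sketch only, with illustrative f and p:

import tensorflow as tf

f = lambda x: x + 1
p = lambda y: tf.equal(y % 2, 0)

ds = tf.data.Dataset.range(10)
unfused = ds.map(f).filter(p)
fused_like = (ds.map(lambda x: (f(x), p(f(x))))
              .filter(lambda y, keep: keep)
              .map(lambda y, keep: y))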
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
deleted file mode 100644
index f7907eb890..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapParallelization optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
-
-
-class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @staticmethod
- def map_functions():
- identity = lambda x: x
- increment = lambda x: x + 1
-
- def assert_greater(x):
- assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
- with ops.control_dependencies([assert_op]):
- return x
-
- def random(_):
- return random_ops.random_uniform([],
- minval=0,
- maxval=10,
- dtype=dtypes.int64,
- seed=42)
-
- def assert_with_random(x):
- x = assert_greater(x)
- return random(x)
-
- return (("Identity", identity, True), ("Increment", increment, True),
- ("AssertGreater", assert_greater, True), ("Random", random, False),
- ("AssertWithRandom", assert_with_random, False))
-
- @parameterized.named_parameters(*map_functions.__func__())
- def testMapParallelization(self, function, should_optimize):
- next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(next_nodes)).map(function).apply(
- optimization.optimize(["map_parallelization"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- for x in range(5):
- result = sess.run(get_next)
-        # No need to run the rest of the pipeline if it was not optimized; the
-        # results would also be hard to check because of the randomness.
- if not should_optimize:
- return
- r = function(x)
- self.assertAllEqual(r, result)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
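Conceptually, map_parallelization rewrites a map with a stateless function into its parallel counterpart, which the test above verifies by asserting "ParallelMap" instead of "Map"; stateful functions such as the random variants are left untouched. A hand-written approximation (the degree of parallelism chosen by the rewrite is an implementation detail; 4 here is illustrative):

import tensorflow as tf

increment = lambda x: x + 1

sequential = tf.data.Dataset.range(5).map(increment)
parallel_like = tf.data.Dataset.range(5).map(increment, num_parallel_calls=4)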
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
deleted file mode 100644
index a5ea85f454..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ /dev/null
@@ -1,223 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapVectorization optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def _get_test_datasets(self,
- base_dataset,
- map_fn,
- num_parallel_calls=None,
- expect_optimized=True):
- """Given base dataset and map fn, creates test datasets.
-
-    Returns a tuple of (unoptimized dataset, optimized dataset). The
- unoptimized dataset has the assertion that Batch follows Map. The optimized
- dataset has the assertion that Map follows Batch, and has the
- "map_vectorization" optimization applied.
-
- Args:
- base_dataset: Input dataset to map->batch
- map_fn: Map function to use
- num_parallel_calls: (Optional.) num_parallel_calls argument for map
- expect_optimized: (Optional.) Whether we expect the optimization to take
- place, in which case we will assert that Batch is followed by Map,
- otherwise Map followed by Batch. Defaults to True.
-
- Returns:
- Tuple of (unoptimized dataset, optimized dataset).
- """
- map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
- batch_size = 100
-
- def _make_dataset(node_names):
- return base_dataset.apply(optimization.assert_next(node_names)).map(
- map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
-
- unoptimized = _make_dataset([map_node_name, "Batch"])
- optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
- [map_node_name, "Batch"]).apply(
- optimization.optimize(["map_vectorization"]))
-
- return unoptimized, optimized
-
- @parameterized.named_parameters(
- ("Basic", lambda x: (x, x + 1), None),
- ("Parallel", lambda x: (x, x + 1), 12),
- ("Gather", lambda x: array_ops.gather(x, 0), 12),
- )
- def testOptimization(self, map_fn, num_parallel_calls):
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
- num_parallel_calls)
- self.assertDatasetsEqual(unoptimized, optimized)
-
- def testOptimizationBadMapFn(self):
- # Test map functions that give an error
- def map_fn(x):
-      # x has leading dimension 5, so this will raise an error
- return array_ops.gather(x, 10)
-
- base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
- 5, drop_remainder=True)
- _, optimized = self._get_test_datasets(base_dataset, map_fn)
- nxt = optimized.make_one_shot_iterator().get_next()
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- r"indices = 10 is not in \[0, 5\)"):
- self.evaluate(nxt)
-
- def testOptimizationWithCapturedInputs(self):
- # Tests that vectorization works with captured inputs
- def map_fn(x):
- return x + y
-
- y = constant_op.constant(1, shape=(2,))
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- # TODO(rachelim): when this optimization works, turn on expect_optimized
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsEqual(optimized, unoptimized)
-
- def testOptimizationIgnoreStateful(self):
-
- def map_fn(x):
- with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
- return array_ops.identity(x)
-
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsRaiseSameError(
- unoptimized, optimized, errors.InvalidArgumentError,
- [("OneShotIterator", "OneShotIterator_1", 1),
- ("IteratorGetNext", "IteratorGetNext_1", 1)])
-
- def testOptimizationIgnoreRagged(self):
- # Make sure we ignore inputs that might not be uniformly sized
- def map_fn(x):
- return array_ops.gather(x, 0)
-
- # output_shape = (?,)
- base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsEqual(unoptimized, optimized)
-
- def testOptimizationIgnoreRaggedMap(self):
- # Don't optimize when the output of the map fn shapes are unknown.
- def map_fn(x):
- return array_ops.tile(x, x)
-
- base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsRaiseSameError(
- unoptimized, optimized, errors.InvalidArgumentError,
- [("OneShotIterator", "OneShotIterator_1", 1),
- ("IteratorGetNext", "IteratorGetNext_1", 1)])
-
-
-class MapVectorizationBenchmark(test.Benchmark):
- # TODO(rachelim): Add a benchmark for more expensive transformations, such as
- # vgg_preprocessing.
-
- def _run(self, x, num_iters=100, name=None):
- deltas = []
- with session.Session() as sess:
- for _ in range(5):
- # Warm up session...
- sess.run(x)
- for _ in range(num_iters):
- start = time.time()
- sess.run(x)
- end = time.time()
- deltas.append(end - start)
- median_time = np.median(deltas)
- self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
- return median_time
-
- def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
- num_elems = np.prod(input_size)
- name_template = "{}__batch_size_{}_input_size_{}_{}"
- unoptimized = input_dataset.map(map_fn).batch(batch_size)
- unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
-
- optimized = unoptimized.apply(optimization.optimize(["map_vectorization"]))
- optimized_op = optimized.make_one_shot_iterator().get_next()
-
- unoptimized_time = self._run(
- unoptimized_op,
- name=name_template.format(str_id, batch_size, num_elems, "unoptimized"))
- optimized_time = self._run(
- optimized_op,
- name=name_template.format(str_id, batch_size, num_elems, "optimized"))
-
- print("Batch size: {}\n"
- "Input size: {}\n"
- "Transformation: {}\n"
- "Speedup: {}\n".format(batch_size, input_size, str_id,
- (unoptimized_time / optimized_time)))
-
- # Known cheap functions
- def benchmarkIdentity(self):
- self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
- "identity")
-
- def benchmarkAddConst(self):
- self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
-
- def benchmarkSelect(self):
- self._benchmark_helper(lambda *args: args[0], "select")
-
- def benchmarkCast(self):
- self._benchmark_helper(
- lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
-
- def _benchmark_helper(self, map_fn, str_id):
- input_sizes = [(10, 10, 3), (10, 100, 300)]
- batch_size = 1000
- for input_size in input_sizes:
- input_dataset = dataset_ops.Dataset.from_tensor_slices(
- (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
- self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
deleted file mode 100644
index 33c250ab2a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class ModelDatasetTest(test_base.DatasetTestBase):
-
- def testModelMap(self):
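-    # Times get_next() over 100 iterations of a matmul-heavy map pipeline
-    # under the model() transformation and prints summary statistics.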
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(math_ops.matmul)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(100):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelParallelMap(self):
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(
- math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(1000):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelMapAndBatch(self):
- batch_size = 16
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.apply(
- batching.map_and_batch(
- math_ops.matmul,
- num_parallel_calls=optimization.AUTOTUNE,
- batch_size=batch_size))
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(10):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelParallelInterleave(self):
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(math_ops.matmul)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset,
- cycle_length=10,
- num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(1000):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelNested(self):
- k = 1024 * 1024
- a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
- b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
- c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
- dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
-
- def f1(a, b, c):
- x, y = a
- return math_ops.matmul(x, y), b, c
-
- def f2(a, b, c):
- x, y = b
- return a, math_ops.matmul(x, y), c
-
- def f3(a, b, c):
- x, y = c
- return a, b, math_ops.matmul(x, y)
-
- dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset, cycle_length=2)
-
- dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset, cycle_length=2)
-
- dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next)
- for _ in range(100):
- start = time.time()
- sess.run(get_next)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
deleted file mode 100644
index b9e60cfa4e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapParallelization optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class NoopEliminationTest(test_base.DatasetTestBase):
-
- def testNoopElimination(self):
- a = constant_op.constant(1, dtype=dtypes.int64)
- b = constant_op.constant(2, dtype=dtypes.int64)
- some_tensor = math_ops.mul(a, b)
-
- dataset = dataset_ops.Dataset.range(5)
- dataset = dataset.apply(
- optimization.assert_next(
- ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
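-    # take(-1), skip(0) and repeat(1) are no-ops and should be eliminated;
-    # repeat(some_tensor), skip(5) and the prefetch(0) ops should remain.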
- dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
- 0).repeat(1).prefetch(0)
- dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
-
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- for x in range(5):
- result = sess.run(get_next)
- self.assertAllEqual(result, x)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
deleted file mode 100644
index 04f499f8c5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
-
-
-class OptimizeDatasetTest(test_base.DatasetTestBase):
-
- def testOptimizationDefault(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize())
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimizationEmpty(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize([]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimizationFusion(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize(["map_and_batch_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimizationStatefulFunction(self):
- dataset = dataset_ops.Dataset.range(10).map(
- lambda _: random_ops.random_uniform([])).batch(10).apply(
- optimization.optimize(["map_and_batch_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(get_next)
-
- def testOptimizationLargeInputFromTensor(self):
- input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
- dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
- optimization.optimize())
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
- sess.run(get_next)
-
- def testOptimizationLargeInputFromTensorSlices(self):
- input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
- dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
- optimization.optimize())
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
deleted file mode 100644
index 66ccaceea5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ /dev/null
@@ -1,851 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tensorflow.ops.parsing_ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import copy
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
-from tensorflow.core.example import example_pb2
-from tensorflow.core.example import feature_pb2
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
-
-# Helpers for creating Example objects
-example = example_pb2.Example
-feature = feature_pb2.Feature
-features = lambda d: feature_pb2.Features(feature=d)
-bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v))
-int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v))
-float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v))
-# Helpers for creating SequenceExample objects
-feature_list = lambda l: feature_pb2.FeatureList(feature=l)
-feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
-sequence_example = example_pb2.SequenceExample
-
-
-def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
- flat_output):
- tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
-
- i = 0 # Index into the flattened output of session.run()
- for k, v in sorted(dict_tensors.items()):
-    # TODO(shivaniagrawal): flat_output is the same as v.
- expected_v = expected_tensors[k]
- tf_logging.info("Comparing key: %s", k)
- print("i", i, "flat_output", flat_output[i], "expected_v", expected_v)
- if sparse_tensor.is_sparse(v):
-      # Three outputs for SparseTensor: indices, values, shape.
- tester.assertEqual([k, len(expected_v)], [k, 3])
- print("i", i, "flat_output", flat_output[i].indices, "expected_v",
- expected_v[0])
- tester.assertAllEqual(expected_v[0], flat_output[i].indices)
- tester.assertAllEqual(expected_v[1], flat_output[i].values)
- tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape)
- else:
- # One output for standard Tensor.
- tester.assertAllEqual(expected_v, flat_output[i])
- i += 1
-
-
-class ParseExampleTest(test_base.DatasetTestBase):
-
- def _test(self,
- input_tensor,
- feature_val,
- expected_values=None,
- expected_err=None):
-
- with self.cached_session() as sess:
- if expected_err:
- with self.assertRaisesWithPredicateMatch(expected_err[0],
- expected_err[1]):
- dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
- contrib_parsing_ops.parse_example_dataset(feature_val))
- get_next = dataset.make_one_shot_iterator().get_next()
- sess.run(get_next)
- return
- else:
- # Returns dict w/ Tensors and SparseTensors.
- # Check values.
- dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
- contrib_parsing_ops.parse_example_dataset(feature_val))
- get_next = dataset.make_one_shot_iterator().get_next()
- result = sess.run(get_next)
- flattened = nest.flatten(result)
- print("result", result, "expected_values", expected_values)
- _compare_output_to_expected(self, result, expected_values, flattened)
-
- # Check shapes; if serialized is a Tensor we need its size to
- # properly check.
- batch_size = (
- input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else
- np.asarray(input_tensor).size)
- for k, f in feature_val.items():
- print("output_shapes as list ",
- tuple(dataset.output_shapes[k].as_list()))
- if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
- self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size)
- elif isinstance(f, parsing_ops.VarLenFeature):
- self.assertEqual(dataset.output_shapes[k].as_list()[1], None)
-
- def testEmptySerializedWithAllDefaults(self):
- sparse_name = "st_a"
- a_name = "a"
- b_name = "b"
- c_name = "c:has_a_tricky_name"
- a_default = [0, 42, 0]
- b_default = np.random.rand(3, 3).astype(bytes)
- c_default = np.random.rand(2).astype(np.float32)
-
- expected_st_a = ( # indices, values, shape
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # sp_a is DT_INT64
- np.array(
- [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
-
- expected_output = {
- sparse_name: expected_st_a,
- a_name: np.array(2 * [[a_default]]),
- b_name: np.array(2 * [b_default]),
- c_name: np.array(2 * [c_default]),
- }
-
- self._test(
- ops.convert_to_tensor(["", ""]), {
- sparse_name:
- parsing_ops.VarLenFeature(dtypes.int64),
- a_name:
- parsing_ops.FixedLenFeature(
- (1, 3), dtypes.int64, default_value=a_default),
- b_name:
- parsing_ops.FixedLenFeature(
- (3, 3), dtypes.string, default_value=b_default),
- c_name:
- parsing_ops.FixedLenFeature(
- (2,), dtypes.float32, default_value=c_default),
- },
- expected_values=expected_output)
-
- def testEmptySerializedWithoutDefaultsShouldFail(self):
- input_features = {
- "st_a":
- parsing_ops.VarLenFeature(dtypes.int64),
- "a":
- parsing_ops.FixedLenFeature(
- (1, 3), dtypes.int64, default_value=[0, 42, 0]),
- "b":
- parsing_ops.FixedLenFeature(
- (3, 3),
- dtypes.string,
- default_value=np.random.rand(3, 3).astype(bytes)),
-        # Feature "c" is missing a default; this gap will cause failure.
- "c":
- parsing_ops.FixedLenFeature(
- (2,), dtype=dtypes.float32),
- }
-
- # Edge case where the key is there but the feature value is empty
- original = example(features=features({"c": feature()}))
- self._test(
- [original.SerializeToString()],
- input_features,
- expected_err=(errors_impl.InvalidArgumentError,
- "Feature: c \\(data type: float\\) is required"))
-
- # Standard case of missing key and value.
- self._test(
- ["", ""],
- input_features,
- expected_err=(errors_impl.InvalidArgumentError,
- "Feature: c \\(data type: float\\) is required"))
-
- def testDenseNotMatchingShapeShouldFail(self):
- original = [
- example(features=features({
- "a": float_feature([1, 1, 3]),
- })), example(features=features({
- "a": float_feature([-1, -1]),
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- self._test(
- ops.convert_to_tensor(serialized),
- {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)},
- expected_err=(errors_impl.InvalidArgumentError,
- "Key: a, Index: 1. Number of float values"))
-
- def testDenseDefaultNoShapeShouldFail(self):
- original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
-
- serialized = [m.SerializeToString() for m in original]
-
- self._test(
- ops.convert_to_tensor(serialized),
- {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
- expected_err=(ValueError, "Missing shape for feature a"))
-
- def testSerializedContainingSparse(self):
- original = [
- example(features=features({
- "st_c": float_feature([3, 4])
- })),
- example(features=features({
- "st_c": float_feature([]), # empty float list
- })),
- example(features=features({
- "st_d": feature(), # feature with nothing in it
- })),
- example(features=features({
- "st_c": float_feature([1, 2, -1]),
- "st_d": bytes_feature([b"hi"])
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_st_c = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
- [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
-                    [4, 3], dtype=np.int64))  # batch == 4, max_elems = 3
-
- expected_st_d = ( # indices, values, shape
- np.array(
- [[3, 0]], dtype=np.int64), np.array(
- ["hi"], dtype=bytes), np.array(
-                    [4, 1], dtype=np.int64))  # batch == 4, max_elems = 1
-
- expected_output = {
- "st_c": expected_st_c,
- "st_d": expected_st_d,
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "st_c": parsing_ops.VarLenFeature(dtypes.float32),
- "st_d": parsing_ops.VarLenFeature(dtypes.string)
- },
- expected_values=expected_output)
-
- def testSerializedContainingSparseFeature(self):
- original = [
- example(features=features({
- "val": float_feature([3, 4]),
- "idx": int64_feature([5, 10])
- })),
- example(features=features({
- "val": float_feature([]), # empty float list
- "idx": int64_feature([])
- })),
- example(features=features({
- "val": feature(), # feature with nothing in it
- # missing idx feature
- })),
- example(features=features({
- "val": float_feature([1, 2, -1]),
- "idx":
- int64_feature([0, 9, 3]) # unsorted
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_sp = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
- np.array(
- [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
- [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
-
- expected_output = {"sp": expected_sp,}
-
- self._test(
- ops.convert_to_tensor(serialized),
- {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])},
- expected_values=expected_output)
-
- def testSerializedContainingSparseFeatureReuse(self):
- original = [
- example(features=features({
- "val1": float_feature([3, 4]),
- "val2": float_feature([5, 6]),
- "idx": int64_feature([5, 10])
- })),
- example(features=features({
- "val1": float_feature([]), # empty float list
- "idx": int64_feature([])
- })),
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_sp1 = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10]], dtype=np.int64), np.array(
- [3.0, 4.0], dtype=np.float32), np.array(
- [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
-
- expected_sp2 = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10]], dtype=np.int64), np.array(
- [5.0, 6.0], dtype=np.float32), np.array(
-                    [2, 7], dtype=np.int64))  # batch == 2, max_elems = 7
-
- expected_output = {
- "sp1": expected_sp1,
- "sp2": expected_sp2,
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "sp1":
- parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
- "sp2":
- parsing_ops.SparseFeature(
- "idx", "val2", dtypes.float32, size=7, already_sorted=True)
- },
- expected_values=expected_output)
-
- def testSerializedContaining3DSparseFeature(self):
- original = [
- example(features=features({
- "val": float_feature([3, 4]),
- "idx0": int64_feature([5, 10]),
- "idx1": int64_feature([0, 2]),
- })),
- example(features=features({
- "val": float_feature([]), # empty float list
- "idx0": int64_feature([]),
- "idx1": int64_feature([]),
- })),
- example(features=features({
- "val": feature(), # feature with nothing in it
- # missing idx feature
- })),
- example(features=features({
- "val": float_feature([1, 2, -1]),
- "idx0": int64_feature([0, 9, 3]), # unsorted
- "idx1": int64_feature([1, 0, 2]),
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_sp = (
- # indices
- np.array(
- [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
- dtype=np.int64),
- # values
- np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
- # shape batch == 4, max_elems = 13
- np.array([4, 13, 3], dtype=np.int64))
-
- expected_output = {"sp": expected_sp,}
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "sp":
- parsing_ops.SparseFeature(["idx0", "idx1"], "val",
- dtypes.float32, [13, 3])
- },
- expected_values=expected_output)
-
- def testSerializedContainingDense(self):
- aname = "a"
- bname = "b*has+a:tricky_name"
- original = [
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"]),
- })), example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b""]),
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- aname:
- np.array(
- [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
- bname:
- np.array(
- ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
- }
-
- # No defaults, values required
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
- },
- expected_values=expected_output)
-
-  # This test is identical to the previous one except
- # for the creation of 'serialized'.
- def testSerializedContainingDenseWithConcat(self):
- aname = "a"
- bname = "b*has+a:tricky_name"
- # TODO(lew): Feature appearing twice should be an error in future.
- original = [
- (example(features=features({
- aname: float_feature([10, 10]),
- })), example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"]),
- }))),
- (
- example(features=features({
- bname: bytes_feature([b"b100"]),
- })),
- example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b"b1"]),
- })),),
- ]
-
- serialized = [
- m.SerializeToString() + n.SerializeToString() for (m, n) in original
- ]
-
- expected_output = {
- aname:
- np.array(
- [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
- bname:
- np.array(
- ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
- }
-
- # No defaults, values required
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
- },
- expected_values=expected_output)
-
- def testSerializedContainingDenseScalar(self):
- original = [
- example(features=features({
- "a": float_feature([1]),
- })), example(features=features({}))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
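-    # The second example has no "a" feature, so the default value -1 is used.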
- expected_output = {
- "a":
- np.array(
- [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "a":
- parsing_ops.FixedLenFeature(
- (1,), dtype=dtypes.float32, default_value=-1),
- },
- expected_values=expected_output)
-
- def testSerializedContainingDenseWithDefaults(self):
- original = [
- example(features=features({
- "a": float_feature([1, 1]),
- })),
- example(features=features({
- "b": bytes_feature([b"b1"]),
- })),
- example(features=features({
- "b": feature()
- })),
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- "a":
- np.array(
- [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
- 1),
- "b":
- np.array(
- ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
- 1),
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "a":
- parsing_ops.FixedLenFeature(
- (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
- "b":
- parsing_ops.FixedLenFeature(
- (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
- },
- expected_values=expected_output)
-
- def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
- expected_st_a = ( # indices, values, shape
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # sp_a is DT_INT64
- np.array(
- [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
- expected_sp = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
- ["a", "b", "c"], dtype="|S"), np.array(
-                    [2, 13], dtype=np.int64))  # batch == 2, max_elems = 13
-
- original = [
- example(features=features({
- "c": float_feature([3, 4]),
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3])
- })), example(features=features({
- "c": float_feature([1, 2]),
- "val": bytes_feature([b"c"]),
- "idx": int64_feature([7])
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- a_default = [1, 2, 3]
- b_default = np.random.rand(3, 3).astype(bytes)
- expected_output = {
- "st_a": expected_st_a,
- "sp": expected_sp,
- "a": np.array(2 * [[a_default]]),
- "b": np.array(2 * [b_default]),
- "c": np.array(
- [[3, 4], [1, 2]], dtype=np.float32),
- }
-
- self._test(
- ops.convert_to_tensor(serialized),
- {
- "st_a":
- parsing_ops.VarLenFeature(dtypes.int64),
- "sp":
- parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
- "a":
- parsing_ops.FixedLenFeature(
- (1, 3), dtypes.int64, default_value=a_default),
- "b":
- parsing_ops.FixedLenFeature(
- (3, 3), dtypes.string, default_value=b_default),
- # Feature "c" must be provided, since it has no default_value.
- "c":
- parsing_ops.FixedLenFeature((2,), dtypes.float32),
- },
- expected_values=expected_output)
-
- def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
- expected_idx = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
- np.array([0, 3, 7, 1]), np.array(
-            [2, 2], dtype=np.int64))  # batch == 2, max_elems = 2
-
- expected_sp = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
- ["a", "b", "d", "c"], dtype="|S"), np.array(
-                    [2, 13], dtype=np.int64))  # batch == 2, max_elems = 13
-
- original = [
- example(features=features({
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3])
- })), example(features=features({
- "val": bytes_feature([b"c", b"d"]),
- "idx": int64_feature([7, 1])
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- "idx": expected_idx,
- "sp": expected_sp,
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "idx":
- parsing_ops.VarLenFeature(dtypes.int64),
- "sp":
- parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
- },
- expected_values=expected_output)
-
- def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
- # During parsing, data read from the serialized proto is stored in buffers.
- # For small batch sizes, a buffer will contain one minibatch entry.
- # For larger batch sizes, a buffer may contain several minibatch
- # entries. This test identified a bug where the code that copied
- # data out of the buffers and into the output tensors assumed each
- # buffer only contained one minibatch entry. The bug has since been fixed.
- truth_int = [i for i in range(batch_size)]
- truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
- for i in range(batch_size)]
-
- expected_str = copy.deepcopy(truth_str)
-
- # Delete some intermediate entries
- for i in range(batch_size):
- col = 1
- if np.random.rand() < 0.25:
- # w.p. 25%, drop out the second entry
- expected_str[i][col] = b"default"
- col -= 1
- truth_str[i].pop()
- if np.random.rand() < 0.25:
- # w.p. 25%, drop out the second entry (possibly again)
- expected_str[i][col] = b"default"
- truth_str[i].pop()
-
- expected_output = {
- # Batch size batch_size, 1 time step.
- "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1),
- # Batch size batch_size, 2 time steps.
- "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2),
- }
-
- original = [
- example(features=features(
- {"a": int64_feature([truth_int[i]]),
- "b": bytes_feature(truth_str[i])}))
- for i in range(batch_size)
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- self._test(
- ops.convert_to_tensor(serialized, dtype=dtypes.string), {
- "a":
- parsing_ops.FixedLenSequenceFeature(
- shape=(),
- dtype=dtypes.int64,
- allow_missing=True,
- default_value=-1),
- "b":
- parsing_ops.FixedLenSequenceFeature(
- shape=[],
- dtype=dtypes.string,
- allow_missing=True,
- default_value="default"),
- },
- expected_values=expected_output)
-
- def testSerializedContainingVarLenDenseLargerBatch(self):
- np.random.seed(3456)
- for batch_size in (1, 10, 20, 100, 256):
- self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
-
- def testSerializedContainingVarLenDense(self):
- aname = "a"
- bname = "b"
- cname = "c"
- dname = "d"
- original = [
- example(features=features({
- cname: int64_feature([2]),
- })),
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str", b"b1_str"]),
- })),
- example(features=features({
- aname: float_feature([-1, -1, 2, 2]),
- bname: bytes_feature([b"b1"]),
- })),
- example(features=features({
- aname: float_feature([]),
- cname: int64_feature([3]),
- })),
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- aname:
- np.array(
- [
- [0, 0, 0, 0],
- [1, 1, 0, 0],
- [-1, -1, 2, 2],
- [0, 0, 0, 0],
- ],
- dtype=np.float32).reshape(4, 2, 2, 1),
- bname:
- np.array(
- [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
- dtype=bytes).reshape(4, 2, 1, 1, 1),
- cname:
- np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
- dname:
- np.empty(shape=(4, 0), dtype=bytes),
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (1, 1, 1), dtype=dtypes.string, allow_missing=True),
- cname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.int64, allow_missing=True),
- dname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True),
- },
- expected_values=expected_output)
-
- # Test with padding values.
- expected_output_custom_padding = dict(expected_output)
- expected_output_custom_padding[aname] = np.array(
- [
- [-2, -2, -2, -2],
- [1, 1, -2, -2],
- [-1, -1, 2, 2],
- [-2, -2, -2, -2],
- ],
- dtype=np.float32).reshape(4, 2, 2, 1)
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1),
- dtype=dtypes.float32,
- allow_missing=True,
- default_value=-2.0),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (1, 1, 1), dtype=dtypes.string, allow_missing=True),
- cname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.int64, allow_missing=True),
- dname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True),
- }, expected_output_custom_padding)
-
- # Change number of required values so the inputs are not a
- # multiple of this size.
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1, 1), dtype=dtypes.string, allow_missing=True),
- },
- expected_err=(
- errors_impl.OpError, "Key: b, Index: 2. "
- "Number of bytes values is not a multiple of stride length."))
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1),
- dtype=dtypes.float32,
- allow_missing=True,
- default_value=[]),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1, 1), dtype=dtypes.string, allow_missing=True),
- },
- expected_err=(ValueError,
- "Cannot reshape a tensor with 0 elements to shape"))
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1, 1), dtype=dtypes.string, allow_missing=True),
- },
- expected_err=(ValueError,
- "First dimension of shape for feature a unknown. "
- "Consider using FixedLenSequenceFeature."))
-
- self._test(
- ops.convert_to_tensor(serialized), {
- cname:
- parsing_ops.FixedLenFeature(
- (1, None), dtype=dtypes.int64, default_value=[[1]]),
- },
- expected_err=(ValueError,
- "All dimensions of shape for feature c need to be known "
- r"but received \(1, None\)."))
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (1, 1, 1), dtype=dtypes.string, allow_missing=True),
- cname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.int64, allow_missing=False),
- dname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True),
- },
- expected_err=(ValueError,
- "Unsupported: FixedLenSequenceFeature requires "
- "allow_missing to be True."))
-
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
deleted file mode 100644
index 7a6a7a709a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ /dev/null
@@ -1,948 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for prefetching_ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import threading
-
-from tensorflow.contrib.data.python.ops import prefetching_ops
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.compat import compat
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.platform import test
-
-
-class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
-
- def setUp(self):
- self._event = threading.Event()
-
- def _create_ds_and_iterator(self, device0, initializable=False):
-
- def gen():
- for i in range(1, 10):
- yield [float(i)]
- if i == 6:
- self._event.set()
-
- with ops.device(device0):
- ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
- if initializable:
- ds_iterator = ds.make_initializable_iterator()
- else:
- ds_iterator = ds.make_one_shot_iterator()
- return (ds, ds_iterator)
-
- def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
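-    # Builds a function-buffering resource on device1 that prefetches
-    # elements from the iterator living on device0, plus ops to read,
-    # reset and destroy that buffer.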
- ds_iterator_handle = ds_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, ds.output_types, ds.output_shapes)
- return remote_iterator.get_next()
-
- target = constant_op.constant(device0)
- with ops.device(device1):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_remote_fn,
- output_types=[dtypes.float32],
- target_device=target,
- string_arg=ds_iterator_handle,
- buffer_size=3,
- shared_name=buffer_name)
-
- with ops.device(device1):
- prefetch_op = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=buffer_resource_handle,
- output_types=[dtypes.float32])
- reset_op = prefetching_ops.function_buffering_resource_reset(
- function_buffer_resource=buffer_resource_handle)
- destroy_op = resource_variable_ops.destroy_resource_op(
- buffer_resource_handle, ignore_lookup_error=True)
-
- return (prefetch_op, reset_op, destroy_op)
-
- def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
- prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
- device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
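-      # The generator sets the event once it has produced element 6, so
-      # waiting here confirms the buffer prefetched ahead of the consumer.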
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- sess.run(destroy_op)
-
- def testSameDeviceCPU(self):
- self._prefetch_fn_helper_one_shot("same_device_cpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/cpu:0")
-
- def testDifferentDeviceCPU(self):
- self._prefetch_fn_helper_one_shot("diff_device_cpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/cpu:1")
-
- def testDifferentDeviceCPUGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- self._prefetch_fn_helper_one_shot("cpu_gpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/gpu:0")
-
- def testReinitialization(self):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/cpu:1"
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
- prefetch_op, reset_op, destroy_op = self._create_ops(
- ds, ds_iterator, "reinit", device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- sess.run(ds_iterator.initializer)
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
-      # Let's reset the function buffering resource and reinitialize the
- # iterator. Should be able to go through this again.
- self._event.clear()
- sess.run(reset_op)
- sess.run(ds_iterator.initializer)
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- sess.run(destroy_op)
-
- def testReinitializationOutOfRange(self):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/cpu:1"
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
- prefetch_op, reset_op, destroy_op = self._create_ops(
- ds, ds_iterator, "reinit", device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- sess.run(ds_iterator.initializer)
- for i in range(1, 10):
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [float(i)])
-      # Try fetching twice after it's over to test end-of-sequence behavior.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- # Now reset everything and try it out again.
- self._event.clear()
- sess.run(reset_op)
- sess.run(ds_iterator.initializer)
- for i in range(1, 10):
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [float(i)])
-      # Try fetching twice after it's over to test end-of-sequence behavior.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- sess.run(destroy_op)
-
- def testStringsGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/gpu:0"
-
- ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
- ds_iterator = ds.make_one_shot_iterator()
- ds_iterator_handle = ds_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, ds.output_types, ds.output_shapes)
- return remote_iterator.get_next()
-
- target = constant_op.constant(device0)
- with ops.device(device1):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_remote_fn,
- output_types=[dtypes.string],
- target_device=target,
- string_arg=ds_iterator_handle,
- buffer_size=3,
- shared_name="strings")
-
- with ops.device(device1):
- prefetch_op = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=buffer_resource_handle,
- output_types=[dtypes.string])
- destroy_op = resource_variable_ops.destroy_resource_op(
- buffer_resource_handle, ignore_lookup_error=True)
-
- with self.cached_session() as sess:
- self.assertEqual([b"a"], sess.run(prefetch_op))
- self.assertEqual([b"b"], sess.run(prefetch_op))
- self.assertEqual([b"c"], sess.run(prefetch_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- sess.run(destroy_op)
-
-
-class PrefetchToDeviceTest(test_base.DatasetTestBase):
-
- def testPrefetchToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToSameDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device(
- "/job:localhost/replica:0/task:0/device:CPU:0"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchDictToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element["a"].dtype)
- self.assertEqual([], next_element["a"].shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual({"a": i}, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchSparseTensorsToDevice(self):
- def make_tensor(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
- host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
-
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- actual = sess.run(next_element)
- self.assertAllEqual([i], actual.values)
- self.assertAllEqual([[0, 0]], actual.indices)
- self.assertAllEqual([2, 2], actual.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/gpu:0"))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceWithReInit(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_initializable_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceGpuWithReInit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/gpu:0"))
-
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
-
-class CopyToDeviceTest(test_base.DatasetTestBase):
-
- def testCopyToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
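# A minimal usage sketch (not part of the original test file) of the pattern
# this class exercises, assuming `copy_to_device` is exported as
# `tf.contrib.data.copy_to_device`. Unlike `prefetch_to_device`, the iterator
# itself lives on the target device, and a small `prefetch` after the copy
# keeps that device fed. The device string is illustrative.
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(tf.contrib.data.copy_to_device("/gpu:0")).prefetch(1)

with tf.device("/gpu:0"):
  iterator = dataset.make_initializable_iterator()
  next_element = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(next_element))  # 0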
- def testCopyToDeviceInt32(self):
- host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int32, next_element.dtype)
- self.assertEqual((4,), next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToSameDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:0"))
-
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceWithPrefetch(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyDictToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element["a"].dtype)
- self.assertEqual([], next_element["a"].shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual({"a": i}, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyDictToDeviceWithPrefetch(self):
- host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element["a"].dtype)
- self.assertEqual([], next_element["a"].shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual({"a": i}, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopySparseTensorsToDevice(self):
-
- def make_tensor(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
-
- host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
-
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- actual = sess.run(next_element)
- self.assertAllEqual([i], actual.values)
- self.assertAllEqual([[0, 0]], actual.indices)
- self.assertAllEqual([2, 2], actual.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopySparseTensorsToDeviceWithPrefetch(self):
-
- def make_tensor(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
-
- host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
-
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- actual = sess.run(next_element)
- self.assertAllEqual([i], actual.values)
- self.assertAllEqual([[0, 0]], actual.indices)
- self.assertAllEqual([2, 2], actual.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuWithPrefetch(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuInt32(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuInt32AndPrefetch(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuStrings(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuStringsAndPrefetch(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDevicePingPongCPUGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- with compat.forward_compatibility_horizon(2018, 8, 4):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0"))
- back_to_cpu_dataset = device_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0"))
-
- with ops.device("/cpu:0"):
- iterator = back_to_cpu_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceWithReInit(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceWithReInitAndPrefetch(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuWithReInit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuWithReInitAndPrefetch(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testIteratorGetNextAsOptionalOnGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(3)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_elem = iterator_ops.get_next_as_optional(iterator)
- elem_has_value_t = next_elem.has_value()
- elem_value_t = next_elem.get_value()
-
- with self.cached_session() as sess:
- # Before initializing the iterator, evaluating the optional fails with
- # a FailedPreconditionError.
- with self.assertRaises(errors.FailedPreconditionError):
- sess.run(elem_has_value_t)
- with self.assertRaises(errors.FailedPreconditionError):
- sess.run(elem_value_t)
-
- # For each element of the dataset, assert that the optional evaluates to
- # the expected value.
- sess.run(iterator.initializer)
- for i in range(3):
- elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
- self.assertTrue(elem_has_value)
- self.assertEqual(i, elem_value)
-
- # After exhausting the iterator, `next_elem.has_value()` will evaluate to
- # false, and attempting to get the value will fail.
- for _ in range(2):
- self.assertFalse(sess.run(elem_has_value_t))
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(elem_value_t)
-
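# A minimal sketch (not part of the original test file) of consuming the
# optional returned by `iterator_ops.get_next_as_optional`, the helper used in
# the test above, without raising once the iterator is exhausted. The sentinel
# value is illustrative.
import tensorflow as tf
from tensorflow.python.data.ops import iterator_ops

dataset = tf.data.Dataset.range(3)
iterator = dataset.make_initializable_iterator()
optional_next = iterator_ops.get_next_as_optional(iterator)

# Each session.run() pulls one optional from the iterator and returns either
# its value or a sentinel once the iterator is exhausted.
element_or_sentinel = tf.cond(
    optional_next.has_value(),
    optional_next.get_value,
    lambda: tf.constant(-1, dtype=tf.int64))

with tf.Session() as sess:
  sess.run(iterator.initializer)
  for _ in range(5):
    print(sess.run(element_or_sentinel))  # 0, 1, 2, -1, -1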
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
deleted file mode 100644
index 2e901587f4..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test RangeDataset."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import counter
-from tensorflow.contrib.data.python.ops import enumerate_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.platform import test
-
-
-class RangeDatasetTest(test_base.DatasetTestBase):
-
- def testEnumerateDataset(self):
- components = (["a", "b"], [1, 2], [37.0, 38])
- start = constant_op.constant(20, dtype=dtypes.int64)
-
- iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply(
- enumerate_ops.enumerate_dataset(start)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual(dtypes.int64, get_next[0].dtype)
- self.assertEqual((), get_next[0].shape)
- self.assertEqual([tensor_shape.TensorShape([])] * 3,
- [t.shape for t in get_next[1]])
-
- with self.cached_session() as sess:
- sess.run(init_op)
- self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
- self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testCounter(self):
- """Test dataset construction using `count`."""
- iterator = (counter.Counter(start=3, step=4)
- .make_one_shot_iterator())
- get_next = iterator.get_next()
- self.assertEqual([], get_next.shape.as_list())
- self.assertEqual(dtypes.int64, get_next.dtype)
-
- negative_iterator = (counter.Counter(start=0, step=-1)
- .make_one_shot_iterator())
- negative_get_next = negative_iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertEqual(3, sess.run(get_next))
- self.assertEqual(3 + 4, sess.run(get_next))
- self.assertEqual(3 + 2 * 4, sess.run(get_next))
-
- self.assertEqual(0, sess.run(negative_get_next))
- self.assertEqual(-1, sess.run(negative_get_next))
- self.assertEqual(-2, sess.run(negative_get_next))
-
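# A minimal sketch (not part of the original test file) of the two helpers
# exercised above, assuming their contrib-era exports `tf.contrib.data.Counter`
# and `tf.contrib.data.enumerate_dataset`; start/step values are illustrative.
import tensorflow as tf

# An unbounded int64 counter: 3, 7, 11, ...
counter_next = (tf.contrib.data.Counter(start=3, step=4)
                .make_one_shot_iterator().get_next())

# Pair each element with an index starting at 20: (20, b"a"), (21, b"b").
letters = tf.data.Dataset.from_tensor_slices(["a", "b"])
enumerated_next = (letters.apply(tf.contrib.data.enumerate_dataset(start=20))
                   .make_one_shot_iterator().get_next())

with tf.Session() as sess:
  print(sess.run(counter_next))     # 3
  print(sess.run(enumerated_next))  # (20, b'a')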
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
deleted file mode 100644
index 66ed547b6d..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ /dev/null
@@ -1,1083 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import zlib
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-
-
-class ReadBatchFeaturesTest(
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
-
- def testRead(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 10]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from file 0.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- 0,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from file 1.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames[1],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- 1,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from both files.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from both files.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, num_epochs=num_epochs)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
-
- def testReadWithEquivalentDataset(self):
- features = {
- "file": parsing_ops.FixedLenFeature([], dtypes.int64),
- "record": parsing_ops.FixedLenFeature([], dtypes.int64),
- }
- dataset = (
- core_readers.TFRecordDataset(self.test_filenames)
- .map(lambda x: parsing_ops.parse_single_example(x, features))
- .repeat(10).batch(2))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
- range(self._num_files), 2, 10):
- actual_batch = sess.run(next_element)
- self.assertAllEqual(file_batch, actual_batch["file"])
- self.assertAllEqual(record_batch, actual_batch["record"])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testReadWithFusedShuffleRepeatDataset(self):
- num_epochs = 5
- total_records = num_epochs * self._num_records
- for batch_size in [1, 2]:
-      # Test that shuffling with the same seed produces the same result.
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs1 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- outputs2 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
- for i in range(len(batch1)):
- self.assertAllEqual(batch1[i], batch2[i])
-
- # Test that shuffling with different seeds produces a different order.
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs1 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- outputs2 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=15).make_one_shot_iterator().get_next()
- all_equal = True
- for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
- for i in range(len(batch1)):
- all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
- self.assertFalse(all_equal)
-
- def testParallelReadersAndParsers(self):
- num_epochs = 5
- for batch_size in [1, 2]:
- for reader_num_threads in [2, 4]:
- for parser_num_threads in [2, 4]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
- ).get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- label_key_provided=True,
- interleave_cycle_length=reader_num_threads)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
- ).get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- interleave_cycle_length=reader_num_threads)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
-
- def testDropFinalBatch(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 10]:
- with ops.Graph().as_default():
- # Basic test: read from file 0.
- outputs = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size,
- drop_final_batch=True).make_one_shot_iterator().get_next()
- for tensor in nest.flatten(outputs):
- if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
- self.assertEqual(tensor.shape[0], batch_size)
-
- def testIndefiniteRepeatShapeInference(self):
- dataset = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=None,
- batch_size=32)
- for shape, clazz in zip(nest.flatten(dataset.output_shapes),
- nest.flatten(dataset.output_classes)):
- if issubclass(clazz, ops.Tensor):
- self.assertEqual(32, shape[0])
-
-
-class MakeCsvDatasetTest(test_base.DatasetTestBase):
-
- def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
- return readers.make_csv_dataset(
- filenames, batch_size=batch_size, num_epochs=num_epochs, **kwargs)
-
- def _setup_files(self, inputs, linebreak="\n", compression_type=None):
- filenames = []
- for i, ip in enumerate(inputs):
- fn = os.path.join(self.get_temp_dir(), "temp_%d.csv" % i)
- contents = linebreak.join(ip).encode("utf-8")
- if compression_type is None:
- with open(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "GZIP":
- with gzip.GzipFile(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "ZLIB":
- contents = zlib.compress(contents)
- with open(fn, "wb") as f:
- f.write(contents)
- else:
- raise ValueError("Unsupported compression_type", compression_type)
- filenames.append(fn)
- return filenames
-
- def _next_expected_batch(self, expected_output, expected_keys, batch_size,
- num_epochs):
- features = {k: [] for k in expected_keys}
- for _ in range(num_epochs):
- for values in expected_output:
- for n, key in enumerate(expected_keys):
- features[key].append(values[n])
- if len(features[expected_keys[0]]) == batch_size:
- yield features
- features = {k: [] for k in expected_keys}
- if features[expected_keys[0]]: # Leftover from the last batch
- yield features
-
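# A worked example of the helper above (not part of the original test file):
# with expected_output=[[0, "a"], [1, "b"], [2, "c"]], expected_keys=["x", "y"],
# batch_size=2 and num_epochs=1, it yields {"x": [0, 1], "y": ["a", "b"]} and
# then the leftover partial batch {"x": [2], "y": ["c"]}.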
- def _verify_output(
- self,
- sess,
- dataset,
- batch_size,
- num_epochs,
- label_name,
- expected_output,
- expected_keys,
- ):
- nxt = dataset.make_one_shot_iterator().get_next()
-
- for expected_features in self._next_expected_batch(
- expected_output,
- expected_keys,
- batch_size,
- num_epochs,
- ):
- actual_features = sess.run(nxt)
-
- if label_name is not None:
- expected_labels = expected_features.pop(label_name)
- self.assertAllEqual(expected_labels, actual_features[1])
- actual_features = actual_features[0]
-
- for k in expected_features.keys():
- # Compare features
- self.assertAllEqual(expected_features[k], actual_features[k])
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(nxt)
-
- def _test_dataset(self,
- inputs,
- expected_output,
- expected_keys,
- batch_size=1,
- num_epochs=1,
- label_name=None,
- **kwargs):
- """Checks that elements produced by CsvDataset match expected output."""
-    # Write inputs as bytes, since TF strings are bytestrings in Python 3.
- filenames = self._setup_files(
- inputs, compression_type=kwargs.get("compression_type", None))
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- filenames,
- batch_size=batch_size,
- num_epochs=num_epochs,
- label_name=label_name,
- **kwargs)
- self._verify_output(sess, dataset, batch_size, num_epochs, label_name,
- expected_output, expected_keys)
-
- def testMakeCSVDataset(self):
- """Tests making a CSV dataset with keys and defaults provided."""
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
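# A minimal sketch (not part of the original test file) of the same call made
# through the contrib-era endpoint `tf.contrib.data.make_csv_dataset`. The file
# name is hypothetical; the column layout mirrors the test above.
import tensorflow as tf

dataset = tf.contrib.data.make_csv_dataset(
    "/tmp/example.csv",  # hypothetical CSV with a "col0,...,col4" header line
    batch_size=1,
    column_names=["col%d" % i for i in range(5)],
    label_name="col0",
    num_epochs=1,
    shuffle=False)
features, label = dataset.make_one_shot_iterator().get_next()
# `features` maps column name -> batched Tensor; `label` is the batched "col0"
# column, split out because `label_name` was given.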
- def testMakeCSVDataset_withBatchSizeAndEpochs(self):
- """Tests making a CSV dataset with keys and defaults provided."""
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=3,
- num_epochs=10,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withCompressionType(self):
- """Tests `compression_type` argument."""
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- for compression_type in ("GZIP", "ZLIB"):
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- compression_type=compression_type,
- )
-
- def testMakeCSVDataset_withBadInputs(self):
- """Tests that exception is raised when input is malformed.
- """
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- filenames = self._setup_files(inputs)
-
- # Duplicate column names
- with self.assertRaises(ValueError):
- self._make_csv_dataset(
- filenames,
- batch_size=1,
- column_defaults=record_defaults,
- label_name="col0",
- column_names=column_names * 2)
-
- # Label key not one of column names
- with self.assertRaises(ValueError):
- self._make_csv_dataset(
- filenames,
- batch_size=1,
- column_defaults=record_defaults,
- label_name="not_a_real_label",
- column_names=column_names)
-
- def testMakeCSVDataset_withNoLabel(self):
- """Tests making a CSV dataset with no label provided."""
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withNoHeader(self):
- """Tests that datasets can be created from CSV files with no header line.
- """
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [["0,1,2,3,4", "5,6,7,8,9"], ["10,11,12,13,14", "15,16,17,18,19"]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=False,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withTypes(self):
- """Tests that defaults can be a dtype instead of a Tensor for required vals.
- """
- record_defaults = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
- dtypes.string
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x[0] for x in column_names), "0,1,2,3,4", "5,6,7,8,9"],
- [
- ",".join(x[0] for x in column_names), "10,11,12,13,14",
- "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withNoColNames(self):
- """Tests that datasets can be created when column names are not specified.
-
- In that case, we should infer the column names from the header lines.
- """
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withTypeInferenceMismatch(self):
-    # Test that an error is raised when the number of fields doesn't match
-    # the number of column names.
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- filenames = self._setup_files(inputs)
- with self.assertRaises(ValueError):
- self._make_csv_dataset(
- filenames,
- column_names=column_names + ["extra_name"],
- column_defaults=None,
- batch_size=2,
- num_epochs=10)
-
- def testMakeCSVDataset_withTypeInference(self):
- """Tests that datasets can be created when no defaults are specified.
-
- In that case, we should infer the types from the first N records.
- """
- column_names = ["col%d" % i for i in range(5)]
- str_int32_max = str(2**33)
- inputs = [[
- ",".join(x for x in column_names),
- "0,%s,2.0,3e50,rabbit" % str_int32_max
- ]]
- expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- )
-
- def testMakeCSVDataset_withTypeInferenceFallthrough(self):
- """Tests that datasets can be created when no defaults are specified.
-
- Tests on a deliberately tricky file.
- """
- column_names = ["col%d" % i for i in range(5)]
- str_int32_max = str(2**33)
- inputs = [[
- ",".join(x for x in column_names),
- ",,,,",
- "0,0,0.0,0.0,0.0",
- "0,%s,2.0,3e50,rabbit" % str_int32_max,
- ",,,,",
- ]]
- expected_output = [[0, 0, 0, 0, b""], [0, 0, 0, 0, b"0.0"],
- [0, 2**33, 2.0, 3e50, b"rabbit"], [0, 0, 0, 0, b""]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- )
-
- def testMakeCSVDataset_withSelectCols(self):
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
- column_names = ["col%d" % i for i in range(5)]
- str_int32_max = str(2**33)
- inputs = [[
- ",".join(x for x in column_names),
- "0,%s,2.0,3e50,rabbit" % str_int32_max
- ]]
- expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
-
- select_cols = [1, 3, 4]
- self._test_dataset(
- inputs,
- expected_output=[[x[i] for i in select_cols] for x in expected_output],
- expected_keys=[column_names[i] for i in select_cols],
- column_names=column_names,
- column_defaults=[record_defaults[i] for i in select_cols],
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- select_columns=select_cols,
- )
-
- # Can still do inference without provided defaults
- self._test_dataset(
- inputs,
- expected_output=[[x[i] for i in select_cols] for x in expected_output],
- expected_keys=[column_names[i] for i in select_cols],
- column_names=column_names,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- select_columns=select_cols,
- )
-
- # Can still do column name inference
- self._test_dataset(
- inputs,
- expected_output=[[x[i] for i in select_cols] for x in expected_output],
- expected_keys=[column_names[i] for i in select_cols],
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- select_columns=select_cols,
- )
-
- # Can specify column names instead of indices
- self._test_dataset(
- inputs,
- expected_output=[[x[i] for i in select_cols] for x in expected_output],
- expected_keys=[column_names[i] for i in select_cols],
- column_names=column_names,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- select_columns=[column_names[i] for i in select_cols],
- )
-
- def testMakeCSVDataset_withSelectColsError(self):
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
- column_names = ["col%d" % i for i in range(5)]
- str_int32_max = str(2**33)
- inputs = [[
- ",".join(x for x in column_names),
- "0,%s,2.0,3e50,rabbit" % str_int32_max
- ]]
-
- select_cols = [1, 3, 4]
- filenames = self._setup_files(inputs)
-
- with self.assertRaises(ValueError):
-      # A mismatch between the number of defaults and the number of selected
-      # columns should raise an error.
- self._make_csv_dataset(
- filenames,
- batch_size=1,
- column_defaults=record_defaults,
- column_names=column_names,
- select_columns=select_cols)
-
- with self.assertRaises(ValueError):
- # Invalid column name should raise an error
- self._make_csv_dataset(
- filenames,
- batch_size=1,
- column_defaults=[[0]],
- column_names=column_names,
- label_name=None,
- select_columns=["invalid_col_name"])
-
- def testMakeCSVDataset_withShuffle(self):
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- def str_series(st):
- return ",".join(str(i) for i in range(st, st + 5))
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [
- [",".join(x for x in column_names)
- ] + [str_series(5 * i) for i in range(15)],
- [",".join(x for x in column_names)] +
- [str_series(5 * i) for i in range(15, 20)],
- ]
-
- filenames = self._setup_files(inputs)
-
- total_records = 20
- for batch_size in [1, 2]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Test that shuffling with the same seed produces the same result
- dataset1 = self._make_csv_dataset(
- filenames,
- column_defaults=record_defaults,
- column_names=column_names,
- batch_size=batch_size,
- header=True,
- shuffle=True,
- shuffle_seed=5,
- num_epochs=2,
- )
- dataset2 = self._make_csv_dataset(
- filenames,
- column_defaults=record_defaults,
- column_names=column_names,
- batch_size=batch_size,
- header=True,
- shuffle=True,
- shuffle_seed=5,
- num_epochs=2,
- )
- outputs1 = dataset1.make_one_shot_iterator().get_next()
- outputs2 = dataset2.make_one_shot_iterator().get_next()
- for _ in range(total_records // batch_size):
- batch1 = nest.flatten(sess.run(outputs1))
- batch2 = nest.flatten(sess.run(outputs2))
- for i in range(len(batch1)):
- self.assertAllEqual(batch1[i], batch2[i])
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Test that shuffling with a different seed produces different results
- dataset1 = self._make_csv_dataset(
- filenames,
- column_defaults=record_defaults,
- column_names=column_names,
- batch_size=batch_size,
- header=True,
- shuffle=True,
- shuffle_seed=5,
- num_epochs=2,
- )
- dataset2 = self._make_csv_dataset(
- filenames,
- column_defaults=record_defaults,
- column_names=column_names,
- batch_size=batch_size,
- header=True,
- shuffle=True,
- shuffle_seed=6,
- num_epochs=2,
- )
- outputs1 = dataset1.make_one_shot_iterator().get_next()
- outputs2 = dataset2.make_one_shot_iterator().get_next()
-          all_equal = True
- for _ in range(total_records // batch_size):
- batch1 = nest.flatten(sess.run(outputs1))
- batch2 = nest.flatten(sess.run(outputs2))
- for i in range(len(batch1)):
- all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
- self.assertFalse(all_equal)
-
- def testIndefiniteRepeatShapeInference(self):
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- filenames = self._setup_files(inputs)
- dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None)
- for shape in nest.flatten(dataset.output_shapes):
- self.assertEqual(32, shape[0])
-
-
-class MakeTFRecordDatasetTest(
- reader_dataset_ops_test_base.TFRecordDatasetTestBase):
-
- def _interleave(self, iterators, cycle_length):
- pending_iterators = iterators
- open_iterators = []
- num_open = 0
- for i in range(cycle_length):
- if pending_iterators:
- open_iterators.append(pending_iterators.pop(0))
- num_open += 1
-
- while num_open:
- for i in range(min(cycle_length, len(open_iterators))):
- if open_iterators[i] is None:
- continue
- try:
- yield next(open_iterators[i])
- except StopIteration:
- if pending_iterators:
- open_iterators[i] = pending_iterators.pop(0)
- else:
- open_iterators[i] = None
- num_open -= 1
-
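# A worked example of the round-robin merge above (not part of the original
# test file): with three iterators over [0, 1], [10, 11] and [20, 21] and
# cycle_length=2, the helper yields 0, 10, 1, 11, 20, 21; the third iterator
# only opens once one of the first two slots is exhausted.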
- def _next_expected_batch(self,
- file_indices,
- batch_size,
- num_epochs,
- cycle_length,
- drop_final_batch,
- use_parser_fn):
-
- def _next_record(file_indices):
- for j in file_indices:
- for i in range(self._num_records):
- yield j, i
-
- def _next_record_interleaved(file_indices, cycle_length):
- return self._interleave([_next_record([i]) for i in file_indices],
- cycle_length)
-
- record_batch = []
- batch_index = 0
- for _ in range(num_epochs):
- if cycle_length == 1:
- next_records = _next_record(file_indices)
- else:
- next_records = _next_record_interleaved(file_indices, cycle_length)
- for f, r in next_records:
- record = self._record(f, r)
- if use_parser_fn:
- record = record[1:]
- record_batch.append(record)
- batch_index += 1
- if len(record_batch) == batch_size:
- yield record_batch
- record_batch = []
- batch_index = 0
- if record_batch and not drop_final_batch:
- yield record_batch
-
- def _verify_records(self,
- sess,
- outputs,
- batch_size,
- file_index,
- num_epochs,
- interleave_cycle_length,
- drop_final_batch,
- use_parser_fn):
- if file_index is not None:
- file_indices = [file_index]
- else:
- file_indices = range(self._num_files)
-
- for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length,
- drop_final_batch, use_parser_fn):
- actual_batch = sess.run(outputs)
- self.assertAllEqual(expected_batch, actual_batch)
-
- def _read_test(self, batch_size, num_epochs, file_index=None,
- num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
- if file_index is None:
- file_pattern = self.test_filenames
- else:
- file_pattern = self.test_filenames[file_index]
-
- if parser_fn:
- fn = lambda x: string_ops.substr(x, 1, 999)
- else:
- fn = None
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs = readers.make_tf_record_dataset(
- file_pattern=file_pattern,
- num_epochs=num_epochs,
- batch_size=batch_size,
- parser_fn=fn,
- num_parallel_reads=num_parallel_reads,
- drop_final_batch=drop_final_batch,
- shuffle=False).make_one_shot_iterator().get_next()
- self._verify_records(
- sess, outputs, batch_size, file_index, num_epochs=num_epochs,
- interleave_cycle_length=num_parallel_reads,
- drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(outputs)
-
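# A minimal sketch (not part of the original test file) of the reader exercised
# by `_read_test`, assuming it is exported as
# `tf.contrib.data.make_tf_record_dataset`. The file pattern is hypothetical.
import tensorflow as tf

dataset = tf.contrib.data.make_tf_record_dataset(
    file_pattern="/tmp/tf_record.*.txt",
    batch_size=2,
    num_epochs=1,
    shuffle=False)
next_batch = dataset.make_one_shot_iterator().get_next()
# Without a `parser_fn`, each element is a batch of raw serialized records.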
- def testRead(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- # Basic test: read from file 0.
- self._read_test(batch_size, num_epochs, 0)
-
- # Basic test: read from file 1.
- self._read_test(batch_size, num_epochs, 1)
-
- # Basic test: read from both files.
- self._read_test(batch_size, num_epochs)
-
- # Basic test: read from both files, with parallel reads.
- self._read_test(batch_size, num_epochs, num_parallel_reads=8)
-
- def testDropFinalBatch(self):
- for batch_size in [1, 2, 10]:
- for num_epochs in [1, 3]:
- # Read from file 0.
- self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
-
- # Read from both files.
- self._read_test(batch_size, num_epochs, drop_final_batch=True)
-
- # Read from both files, with parallel reads.
- self._read_test(batch_size, num_epochs, num_parallel_reads=8,
- drop_final_batch=True)
-
- def testParserFn(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- for drop_final_batch in [False, True]:
- self._read_test(batch_size, num_epochs, parser_fn=True,
- drop_final_batch=drop_final_batch)
- self._read_test(batch_size, num_epochs, num_parallel_reads=8,
- parser_fn=True, drop_final_batch=drop_final_batch)
-
- def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
- seed=None):
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.make_tf_record_dataset(
- file_pattern=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size,
- num_parallel_reads=num_parallel_reads,
- shuffle=True,
- shuffle_seed=seed)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- sess.run(iterator.initializer)
- first_batches = []
- try:
- while True:
- first_batches.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
-
- sess.run(iterator.initializer)
- second_batches = []
- try:
- while True:
- second_batches.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
-
- self.assertEqual(len(first_batches), len(second_batches))
- if seed is not None:
-          # If a seed is set, both passes should produce the same batches.
- for i in range(len(first_batches)):
- self.assertAllEqual(first_batches[i], second_batches[i])
-
- expected = []
- for f in range(self._num_files):
- for r in range(self._num_records):
- expected.extend([self._record(f, r)] * num_epochs)
-
- for batches in (first_batches, second_batches):
- actual = []
- for b in batches:
- actual.extend(b)
- self.assertAllEqual(sorted(expected), sorted(actual))
-
- def testShuffle(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- for num_parallel_reads in [1, 2]:
- # Test that all expected elements are produced
- self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
- # Test that elements are produced in a consistent order if
- # you specify a seed.
- self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
- seed=21345)
-
- def testIndefiniteRepeatShapeInference(self):
- dataset = readers.make_tf_record_dataset(
- file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
- for shape in nest.flatten(dataset.output_shapes):
- self.assertEqual(32, shape[0])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
deleted file mode 100644
index f443b5501b..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ /dev/null
@@ -1,353 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing reader datasets."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import zlib
-
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.core.example import example_pb2
-from tensorflow.core.example import feature_pb2
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.lib.io import python_io
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.util import compat
-
-
-class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing FixedLengthRecordDataset."""
-
- def setUp(self):
- super(FixedLengthRecordDatasetTestBase, self).setUp()
- self._num_files = 2
- self._num_records = 7
- self._header_bytes = 5
- self._record_bytes = 3
- self._footer_bytes = 2
-
- def _record(self, f, r):
- return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
- filenames.append(fn)
- with open(fn, "wb") as f:
- f.write(b"H" * self._header_bytes)
- for j in range(self._num_records):
- f.write(self._record(i, j))
- f.write(b"F" * self._footer_bytes)
- return filenames
-
-
-class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing `make_batched_feature_dataset`."""
-
- def setUp(self):
- super(ReadBatchFeaturesTestBase, self).setUp()
- self._num_files = 2
- self._num_records = 7
- self.test_filenames = self._createFiles()
-
- def make_batch_feature(self,
- filenames,
- num_epochs,
- batch_size,
- label_key=None,
- reader_num_threads=1,
- parser_num_threads=1,
- shuffle=False,
- shuffle_seed=None,
- drop_final_batch=False):
- self.filenames = filenames
- self.num_epochs = num_epochs
- self.batch_size = batch_size
-
- return readers.make_batched_features_dataset(
- file_pattern=self.filenames,
- batch_size=self.batch_size,
- features={
- "file": parsing_ops.FixedLenFeature([], dtypes.int64),
- "record": parsing_ops.FixedLenFeature([], dtypes.int64),
- "keywords": parsing_ops.VarLenFeature(dtypes.string),
- "label": parsing_ops.FixedLenFeature([], dtypes.string),
- },
- label_key=label_key,
- reader=core_readers.TFRecordDataset,
- num_epochs=self.num_epochs,
- shuffle=shuffle,
- shuffle_seed=shuffle_seed,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads,
- drop_final_batch=drop_final_batch)
-
- def _record(self, f, r, l):
- example = example_pb2.Example(
- features=feature_pb2.Features(
- feature={
- "file":
- feature_pb2.Feature(
- int64_list=feature_pb2.Int64List(value=[f])),
- "record":
- feature_pb2.Feature(
- int64_list=feature_pb2.Int64List(value=[r])),
- "keywords":
- feature_pb2.Feature(
- bytes_list=feature_pb2.BytesList(
- value=self._get_keywords(f, r))),
- "label":
- feature_pb2.Feature(
- bytes_list=feature_pb2.BytesList(
- value=[compat.as_bytes(l)]))
- }))
- return example.SerializeToString()
-
- def _get_keywords(self, f, r):
- num_keywords = 1 + (f + r) % 2
- keywords = []
- for index in range(num_keywords):
- keywords.append(compat.as_bytes("keyword%d" % index))
- return keywords
-
- def _sum_keywords(self, num_files):
- sum_keywords = 0
- for i in range(num_files):
- for j in range(self._num_records):
- sum_keywords += 1 + (i + j) % 2
- return sum_keywords
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = python_io.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._record(i, j, "fake-label"))
- writer.close()
- return filenames
-
- def _run_actual_batch(self, outputs, sess, label_key_provided=False):
- if label_key_provided:
-      # `outputs` is a tuple of (features dict, label).
- label_op = outputs[1]
- features_op = outputs[0]
- else:
- features_op = outputs
- label_op = features_op["label"]
- file_op = features_op["file"]
- keywords_indices_op = features_op["keywords"].indices
- keywords_values_op = features_op["keywords"].values
- keywords_dense_shape_op = features_op["keywords"].dense_shape
- record_op = features_op["record"]
- return sess.run([
- file_op, keywords_indices_op, keywords_values_op,
- keywords_dense_shape_op, record_op, label_op
- ])
-
- def _next_actual_batch(self, sess, label_key_provided=False):
- return self._run_actual_batch(self.outputs, sess, label_key_provided)
-
- def _interleave(self, iterators, cycle_length):
- pending_iterators = iterators
- open_iterators = []
- num_open = 0
- for i in range(cycle_length):
- if pending_iterators:
- open_iterators.append(pending_iterators.pop(0))
- num_open += 1
-
- while num_open:
- for i in range(min(cycle_length, len(open_iterators))):
- if open_iterators[i] is None:
- continue
- try:
- yield next(open_iterators[i])
- except StopIteration:
- if pending_iterators:
- open_iterators[i] = pending_iterators.pop(0)
- else:
- open_iterators[i] = None
- num_open -= 1
-
- def _next_expected_batch(self,
- file_indices,
- batch_size,
- num_epochs,
- cycle_length=1):
-
- def _next_record(file_indices):
- for j in file_indices:
- for i in range(self._num_records):
- yield j, i, compat.as_bytes("fake-label")
-
- def _next_record_interleaved(file_indices, cycle_length):
- return self._interleave([_next_record([i]) for i in file_indices],
- cycle_length)
-
- file_batch = []
- keywords_batch_indices = []
- keywords_batch_values = []
- keywords_batch_max_len = 0
- record_batch = []
- batch_index = 0
- label_batch = []
- for _ in range(num_epochs):
- if cycle_length == 1:
- next_records = _next_record(file_indices)
- else:
- next_records = _next_record_interleaved(file_indices, cycle_length)
- for record in next_records:
- f = record[0]
- r = record[1]
- label_batch.append(record[2])
- file_batch.append(f)
- record_batch.append(r)
- keywords = self._get_keywords(f, r)
- keywords_batch_values.extend(keywords)
- keywords_batch_indices.extend(
- [[batch_index, i] for i in range(len(keywords))])
- batch_index += 1
- keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
- if len(file_batch) == batch_size:
- yield [
- file_batch, keywords_batch_indices, keywords_batch_values,
- [batch_size, keywords_batch_max_len], record_batch, label_batch
- ]
- file_batch = []
- keywords_batch_indices = []
- keywords_batch_values = []
- keywords_batch_max_len = 0
- record_batch = []
- batch_index = 0
- label_batch = []
- if file_batch:
- yield [
- file_batch, keywords_batch_indices, keywords_batch_values,
- [len(file_batch), keywords_batch_max_len], record_batch, label_batch
- ]
-
- def verify_records(self,
- sess,
- batch_size,
- file_index=None,
- num_epochs=1,
- label_key_provided=False,
- interleave_cycle_length=1):
- if file_index is not None:
- file_indices = [file_index]
- else:
- file_indices = range(self._num_files)
-
- for expected_batch in self._next_expected_batch(
- file_indices,
- batch_size,
- num_epochs,
- cycle_length=interleave_cycle_length):
- actual_batch = self._next_actual_batch(
- sess, label_key_provided=label_key_provided)
- for i in range(len(expected_batch)):
- self.assertAllEqual(expected_batch[i], actual_batch[i])
-
-
-class TextLineDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing TextLineDataset."""
-
- def _lineText(self, f, l):
- return compat.as_bytes("%d: %d" % (f, l))
-
- def _createFiles(self,
- num_files,
- num_lines,
- crlf=False,
- compression_type=None):
- filenames = []
- for i in range(num_files):
- fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
- filenames.append(fn)
- contents = []
- for j in range(num_lines):
- contents.append(self._lineText(i, j))
-        # Always include a newline after the record, except at the end of
-        # the file, where a trailing newline is written only for the first
-        # file.
- if j + 1 != num_lines or i == 0:
- contents.append(b"\r\n" if crlf else b"\n")
- contents = b"".join(contents)
-
- if not compression_type:
- with open(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "GZIP":
- with gzip.GzipFile(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "ZLIB":
- contents = zlib.compress(contents)
- with open(fn, "wb") as f:
- f.write(contents)
- else:
- raise ValueError("Unsupported compression_type", compression_type)
-
- return filenames
-
-
-class TFRecordDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing TFRecordDataset."""
-
- def setUp(self):
- super(TFRecordDatasetTestBase, self).setUp()
- self._num_files = 2
- self._num_records = 7
-
- self.test_filenames = self._createFiles()
-
- self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
- self.num_epochs = array_ops.placeholder_with_default(
- constant_op.constant(1, dtypes.int64), shape=[])
- self.compression_type = array_ops.placeholder_with_default("", shape=[])
- self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_dataset = core_readers.TFRecordDataset(
- self.filenames, self.compression_type).repeat(self.num_epochs)
- batch_dataset = repeat_dataset.batch(self.batch_size)
-
- iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
- self.init_op = iterator.make_initializer(repeat_dataset)
- self.init_batch_op = iterator.make_initializer(batch_dataset)
- self.get_next = iterator.get_next()
-
- def _record(self, f, r):
- return compat.as_bytes("Record %d of file %d" % (r, f))
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = python_io.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._record(i, j))
- writer.close()
- return filenames
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
index cc22ea1df7..e7281d5318 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
@@ -25,49 +25,11 @@ from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @parameterized.named_parameters(
- ("Zero", 0, 1),
- ("Five", 5, 1),
- ("Ten", 10, 1),
- ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
- ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
- "Dataset had more than one element."),
- )
- def testGetSingleElement(self, skip, take, error=None, error_msg=None):
- skip_t = array_ops.placeholder(dtypes.int64, shape=[])
- take_t = array_ops.placeholder(dtypes.int64, shape=[])
-
- def make_sparse(x):
- x_1d = array_ops.reshape(x, [1])
- x_2d = array_ops.reshape(x, [1, 1])
- return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
-
- dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
- lambda x: (x * x, make_sparse(x))).take(take_t)
- element = get_single_element.get_single_element(dataset)
-
- with self.cached_session() as sess:
- if error is None:
- dense_val, sparse_val = sess.run(
- element, feed_dict={
- skip_t: skip,
- take_t: take
- })
- self.assertEqual(skip * skip, dense_val)
- self.assertAllEqual([[skip]], sparse_val.indices)
- self.assertAllEqual([skip], sparse_val.values)
- self.assertAllEqual([skip], sparse_val.dense_shape)
- else:
- with self.assertRaisesRegexp(error, error_msg):
- sess.run(element, feed_dict={skip_t: skip, take_t: take})
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("SumZero", 0),
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
deleted file mode 100644
index 32474bd411..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-from absl.testing import parameterized
-import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
-
-from tensorflow.contrib.data.python.ops import resampling
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-def _time_resampling(
- test_obj, data_np, target_dist, init_dist, num_to_sample):
- dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()
-
- # Reshape distribution via rejection sampling.
- dataset = dataset.apply(
- resampling.rejection_resample(
- class_func=lambda x: x,
- target_dist=target_dist,
- initial_dist=init_dist,
- seed=142))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with test_obj.test_session() as sess:
- start_time = time.time()
- for _ in xrange(num_to_sample):
- sess.run(get_next)
- end_time = time.time()
-
- return end_time - start_time
-
-
-class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @parameterized.named_parameters(
- ("InitialDistributionKnown", True),
- ("InitialDistributionUnknown", False))
- def testDistribution(self, initial_known):
- classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
- target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
- initial_dist = [0.2] * 5 if initial_known else None
- classes = math_ops.to_int64(classes) # needed for Windows build.
- dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
- 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
-
- get_next = dataset.apply(
- resampling.rejection_resample(
- target_dist=target_dist,
- initial_dist=initial_dist,
- class_func=lambda c, _: c,
- seed=27)).make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- returned = []
- while len(returned) < 4000:
- returned.append(sess.run(get_next))
-
- returned_classes, returned_classes_and_data = zip(*returned)
- _, returned_data = zip(*returned_classes_and_data)
- self.assertAllEqual([compat.as_bytes(str(c))
- for c in returned_classes], returned_data)
- total_returned = len(returned_classes)
- class_counts = np.array([
- len([True for v in returned_classes if v == c])
- for c in range(5)])
- returned_dist = class_counts / total_returned
- self.assertAllClose(target_dist, returned_dist, atol=1e-2)
-
- @parameterized.named_parameters(
- ("OnlyInitial", True),
- ("NotInitial", False))
- def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
- init_dist = [0.5, 0.5]
- target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
- num_classes = len(init_dist)
- # We don't need many samples to test that this works.
- num_samples = 100
- data_np = np.random.choice(num_classes, num_samples, p=init_dist)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
-
- # Reshape distribution.
- dataset = dataset.apply(
- resampling.rejection_resample(
- class_func=lambda x: x,
- target_dist=target_dist,
- initial_dist=init_dist))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- returned = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- returned.append(sess.run(get_next))
-
- def testRandomClasses(self):
- init_dist = [0.25, 0.25, 0.25, 0.25]
- target_dist = [0.0, 0.0, 0.0, 1.0]
- num_classes = len(init_dist)
-    # We don't need many samples to test a Dirac-delta target distribution.
- num_samples = 100
- data_np = np.random.choice(num_classes, num_samples, p=init_dist)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
-
- # Apply a random mapping that preserves the data distribution.
- def _remap_fn(_):
- return math_ops.cast(random_ops.random_uniform([1]) * num_classes,
- dtypes.int32)[0]
- dataset = dataset.map(_remap_fn)
-
- # Reshape distribution.
- dataset = dataset.apply(
- resampling.rejection_resample(
- class_func=lambda x: x,
- target_dist=target_dist,
- initial_dist=init_dist))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- returned = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- returned.append(sess.run(get_next))
-
- classes, _ = zip(*returned)
- bincount = np.bincount(
- np.array(classes),
- minlength=num_classes).astype(np.float32) / len(classes)
-
- self.assertAllClose(target_dist, bincount, atol=1e-2)
-
-
-class ResampleDatasetBenchmark(test.Benchmark):
-
- def benchmarkResamplePerformance(self):
- init_dist = [0.25, 0.25, 0.25, 0.25]
- target_dist = [0.0, 0.0, 0.0, 1.0]
- num_classes = len(init_dist)
-    # We don't need many samples to test a Dirac-delta target distribution.
- num_samples = 1000
- data_np = np.random.choice(num_classes, num_samples, p=init_dist)
-
- resample_time = _time_resampling(
- self, data_np, target_dist, init_dist, num_to_sample=1000)
-
- self.report_benchmark(
- iters=1000, wall_time=resample_time, name="benchmark_resample")
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
deleted file mode 100644
index bdf80eae4e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import scan_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class ScanDatasetTest(test_base.DatasetTestBase):
-
- def _counting_dataset(self, start, scan_fn):
- return dataset_ops.Dataset.from_tensors(0).repeat().apply(
- scan_ops.scan(start, scan_fn))
-
- def testCount(self):
- def make_scan_fn(step):
- return lambda state, _: (state + step, state)
-
- start = array_ops.placeholder(dtypes.int32, shape=[])
- step = array_ops.placeholder(dtypes.int32, shape=[])
- take = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = self._counting_dataset(
- start, make_scan_fn(step)).take(take).make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
-
- for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
- (10, 2, 10), (10, -1, 10),
- (10, -2, 10)]:
- sess.run(iterator.initializer,
- feed_dict={start: start_val, step: step_val, take: take_val})
- for expected, _ in zip(
- itertools.count(start_val, step_val), range(take_val)):
- self.assertEqual(expected, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- @test_util.run_in_graph_and_eager_modes
- def testFibonacci(self):
- iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
- scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
- ).make_one_shot_iterator()
-
- if context.executing_eagerly():
- next_element = iterator.get_next
- else:
- get_next = iterator.get_next()
- next_element = lambda: get_next
-
- self.assertEqual(1, self.evaluate(next_element()))
- self.assertEqual(1, self.evaluate(next_element()))
- self.assertEqual(2, self.evaluate(next_element()))
- self.assertEqual(3, self.evaluate(next_element()))
- self.assertEqual(5, self.evaluate(next_element()))
- self.assertEqual(8, self.evaluate(next_element()))
-
- def testSparseCount(self):
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1])),
- dense_shape=np.array([1, 1]))
-
- def make_scan_fn(step):
- return lambda state, _: (_sparse(state.values[0] + step), state)
-
- start = array_ops.placeholder(dtypes.int32, shape=[])
- step = array_ops.placeholder(dtypes.int32, shape=[])
- take = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = self._counting_dataset(
- _sparse(start),
- make_scan_fn(step)).take(take).make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
-
- for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
- (10, 2, 10), (10, -1, 10),
- (10, -2, 10)]:
- sess.run(iterator.initializer,
- feed_dict={start: start_val, step: step_val, take: take_val})
- for expected, _ in zip(
- itertools.count(start_val, step_val), range(take_val)):
- self.assertEqual(expected, sess.run(next_element).values[0])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testChangingStateShape(self):
- # Test the fixed-point shape invariant calculations: start with
- # initial values with known shapes, and use a scan function that
- # changes the size of the state on each element.
- def _scan_fn(state, input_value):
- # Statically known rank, but dynamic length.
- ret_longer_vector = array_ops.concat([state[0], state[0]], 0)
- # Statically unknown rank.
- ret_larger_rank = array_ops.expand_dims(state[1], 0)
- return (ret_longer_vector, ret_larger_rank), (state, input_value)
-
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
- scan_ops.scan(([0], 1), _scan_fn))
- self.assertEqual([None], dataset.output_shapes[0][0].as_list())
- self.assertIs(None, dataset.output_shapes[0][1].ndims)
- self.assertEqual([], dataset.output_shapes[1].as_list())
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(5):
- (longer_vector_val, larger_rank_val), _ = sess.run(next_element)
- self.assertAllEqual([0] * (2**i), longer_vector_val)
- self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testIncorrectStateType(self):
-
- def _scan_fn(state, _):
- return constant_op.constant(1, dtype=dtypes.int64), state
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- TypeError,
- "The element types for the new state must match the initial state."):
- dataset.apply(
- scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
-
- def testIncorrectReturnType(self):
-
- def _scan_fn(unused_state, unused_input_value):
- return constant_op.constant(1, dtype=dtypes.int64)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- TypeError,
- "The scan function must return a pair comprising the new state and the "
- "output value."):
- dataset.apply(
- scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
deleted file mode 100644
index aa89674c6e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ /dev/null
@@ -1,555 +0,0 @@
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_library(
- name = "dataset_serialization_test_base",
- srcs = [
- "dataset_serialization_test_base.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:iterator_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "batch_dataset_serialization_test",
- size = "medium",
- srcs = ["batch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "cache_dataset_serialization_test",
- size = "small",
- srcs = ["cache_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "concatenate_dataset_serialization_test",
- size = "small",
- srcs = ["concatenate_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "csv_dataset_serialization_test",
- size = "small",
- srcs = ["csv_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- ],
-)
-
-py_test(
- name = "dataset_constructor_serialization_test",
- size = "medium",
- srcs = ["dataset_constructor_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "filter_dataset_serialization_test",
- size = "medium",
- srcs = ["filter_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "fixed_length_record_dataset_serialization_test",
- size = "medium",
- srcs = ["fixed_length_record_dataset_serialization_test.py"],
- shard_count = 4,
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "flat_map_dataset_serialization_test",
- size = "medium",
- srcs = ["flat_map_dataset_serialization_test.py"],
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "group_by_reducer_serialization_test",
- size = "medium",
- srcs = ["group_by_reducer_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "group_by_window_serialization_test",
- size = "medium",
- srcs = ["group_by_window_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "ignore_errors_serialization_test",
- size = "small",
- srcs = ["ignore_errors_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "interleave_dataset_serialization_test",
- size = "medium",
- srcs = ["interleave_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "map_and_batch_dataset_serialization_test",
- size = "medium",
- srcs = ["map_and_batch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "map_dataset_serialization_test",
- size = "medium",
- srcs = ["map_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "optimize_dataset_serialization_test",
- size = "small",
- srcs = ["optimize_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "padded_batch_dataset_serialization_test",
- size = "medium",
- srcs = ["padded_batch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "parallel_interleave_dataset_serialization_test",
- size = "medium",
- srcs = ["parallel_interleave_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "parallel_map_dataset_serialization_test",
- size = "medium",
- srcs = ["parallel_map_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "parse_example_dataset_serialization_test",
- size = "medium",
- srcs = ["parse_example_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
- name = "prefetch_dataset_serialization_test",
- size = "small",
- srcs = ["prefetch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "range_dataset_serialization_test",
- size = "small",
- srcs = ["range_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "sample_from_datasets_serialization_test",
- size = "medium",
- srcs = ["sample_from_datasets_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "scan_dataset_serialization_test",
- size = "small",
- srcs = ["scan_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:scan_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "sequence_dataset_serialization_test",
- size = "medium",
- srcs = ["sequence_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "serialization_integration_test",
- size = "small",
- srcs = ["serialization_integration_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "shuffle_and_repeat_dataset_serialization_test",
- size = "medium",
- srcs = ["shuffle_and_repeat_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "shuffle_dataset_serialization_test",
- size = "medium",
- srcs = ["shuffle_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "sql_dataset_serialization_test",
- size = "small",
- srcs = ["sql_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- ],
-)
-
-py_test(
- name = "stats_dataset_serialization_test",
- size = "medium",
- srcs = ["stats_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "textline_dataset_serialization_test",
- size = "medium",
- srcs = ["textline_dataset_serialization_test.py"],
- shard_count = 4,
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "tf_record_dataset_serialization_test",
- size = "medium",
- srcs = ["tf_record_dataset_serialization_test.py"],
- shard_count = 4,
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "unbatch_dataset_serialization_test",
- size = "medium",
- srcs = ["unbatch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "unique_dataset_serialization_test",
- size = "small",
- srcs = ["unique_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "zip_dataset_serialization_test",
- size = "small",
- srcs = ["zip_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py
deleted file mode 100644
index af87d8b608..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the BatchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class BatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
- components = (
- np.arange(tensor_slice_len),
- np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(tensor_slice_len))
-
- return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size)
-
- def testCore(self):
- tensor_slice_len = 8
- batch_size = 2
- num_outputs = tensor_slice_len // batch_size
- self.run_core_tests(
- lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
- lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
- num_outputs)
-
- def _build_dataset_dense_to_sparse(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4, [12]))
-
- def testDenseToSparseBatchDatasetCore(self):
- components = np.random.randint(5, size=(40,)).astype(np.int32)
- diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
-
- num_outputs = len(components) // 4
- self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
- lambda: self._build_dataset_dense_to_sparse(diff_comp),
- num_outputs)
-
- def _sparse(self, i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- def _build_dataset_sparse(self, batch_size=5):
- return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)
-
- def testSparseCore(self):
- self.run_core_tests(self._build_dataset_sparse,
- lambda: self._build_dataset_sparse(2), 2)
-
- def _build_dataset_nested_sparse(self):
- return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)
-
- def testNestedSparseCore(self):
- self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
deleted file mode 100644
index 1b6059ccbc..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
+++ /dev/null
@@ -1,253 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the CacheDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class CacheDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase,
- parameterized.TestCase):
-
- def setUp(self):
- self.range_size = 10
- self.num_repeats = 3
- self.num_outputs = self.range_size * self.num_repeats
- self.cache_file_prefix = 'test'
-
- def make_dataset_fn(self, is_memory):
- if is_memory:
- filename = ''
- else:
- filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)
-
- def ds_fn():
- return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
- self.num_repeats)
-
- return ds_fn
-
- def expected_outputs(self):
- return list(range(self.range_size)) * self.num_repeats
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointBeforeOneEpoch(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 5 entries from iterator and save checkpoint.
- outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
- self.assertSequenceEqual(outputs, range(5))
-
- # Restore from checkpoint and produce the rest of the elements from the
- # iterator.
- outputs.extend(
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False))
- self.assertSequenceEqual(outputs, self.expected_outputs())
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 8 entries from iterator but save checkpoint after producing 5.
- outputs = self.gen_outputs(
- ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
- self.assertSequenceEqual(outputs, range(8))
-
- if is_memory:
- outputs = outputs[:5]
- outputs.extend(
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False))
- self.assertSequenceEqual(outputs, self.expected_outputs())
- else:
-      # Restoring from the checkpoint and running GetNext should raise
-      # `AlreadyExistsError` now because the lockfile already exists.
- with self.assertRaises(errors.AlreadyExistsError):
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointAfterOneEpoch(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 15 entries from iterator and save checkpoint.
- outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
-
- # Restore from checkpoint and produce the rest of the elements from the
- # iterator.
- outputs.extend(
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 15,
- ckpt_saved=True,
- verify_exhausted=False))
- self.assertSequenceEqual(outputs, self.expected_outputs())
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 18 entries from iterator but save checkpoint after producing 15.
- outputs = self.gen_outputs(
- ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
- self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))
-
- outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 15,
- ckpt_saved=True,
- verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 13 entries from iterator but save checkpoint after producing 5.
- outputs = self.gen_outputs(
- ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
- self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))
-
- # Since we ran for more than one epoch, the cache was completely written.
- # The ckpt was saved when the iterator was in cache-write mode. Test that
- # the iterator falls back to read mode after restoring if the cache has
- # been completely written.
-
- outputs = list(range(5)) + self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointUnusedWriterIterator(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Checkpoint before get_next is called even once.
- outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
- self.assertSequenceEqual(outputs, [])
-
- outputs = self.gen_outputs(
- ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Produce 5 elements and checkpoint.
- outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
- self.assertSequenceEqual(outputs, range(5))
-
- # Restore from checkpoint, then produce no elements and checkpoint.
- outputs.extend(
- self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
- self.assertSequenceEqual(outputs, range(5))
-
- # Restore from checkpoint and produce rest of the elements.
- outputs.extend(
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False))
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testUnusedCheckpointError(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Produce 5 elements and save ckpt.
- outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
- self.assertSequenceEqual(outputs, range(5))
-
- if is_memory:
- outputs = self.gen_outputs(
- ds_fn, [], self.num_outputs, verify_exhausted=False)
- self.assertSequenceEqual(outputs, self.expected_outputs())
- else:
-      # Since the complete cache has not been written, a new iterator that does
-      # not restore the checkpoint will raise an error because a partial cache
-      # shard already exists.
- with self.assertRaises(errors.AlreadyExistsError):
- outputs = self.gen_outputs(
- ds_fn, [], self.num_outputs, verify_exhausted=False)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testIgnoreCheckpointIfCacheWritten(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Produce 15 elements and save ckpt. This will write the complete cache.
- outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
-
- # Build the iterator again but do not restore from ckpt. Since the cache
- # has already been written we should be able to use it.
- outputs = self.gen_outputs(
- ds_fn, [], self.num_outputs, verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py
deleted file mode 100644
index 96f13d75a3..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ConcatenateDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class ConcatenateDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_concatenate_dataset(self, var_array):
- input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 4))
- to_concatenate_components = (np.tile(
- np.array([[5], [6], [7], [8], [9]]), 20), var_array)
-
- return dataset_ops.Dataset.from_tensor_slices(input_components).concatenate(
- dataset_ops.Dataset.from_tensor_slices(to_concatenate_components))
-
- def testConcatenateCore(self):
- num_outputs = 9
- array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15)
- diff_array = np.array([[1], [2], [3], [4], [5]])
- self.run_core_tests(lambda: self._build_concatenate_dataset(array),
- lambda: self._build_concatenate_dataset(diff_array),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py
deleted file mode 100644
index 247f2046ea..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the CsvDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.platform import test
-
-
-class CsvDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self._num_cols = 7
- self._num_rows = 10
- self._num_epochs = 14
- self._num_outputs = self._num_rows * self._num_epochs
-
- inputs = [
- ",".join(str(self._num_cols * j + i)
- for i in range(self._num_cols))
- for j in range(self._num_rows)
- ]
- contents = "\n".join(inputs).encode("utf-8")
-
- self._filename = os.path.join(self.get_temp_dir(), "file.csv")
- self._compressed = os.path.join(self.get_temp_dir(),
- "comp.csv") # GZip compressed
-
- with open(self._filename, "wb") as f:
- f.write(contents)
- with gzip.GzipFile(self._compressed, "wb") as f:
- f.write(contents)
-
- def ds_func(self, **kwargs):
- compression_type = kwargs.get("compression_type", None)
- if compression_type == "GZIP":
- filename = self._compressed
- elif compression_type is None:
- filename = self._filename
- else:
-      raise ValueError("Invalid compression type: %s" % compression_type)
-
- return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
-
- def testSerializationCore(self):
- defs = [[0]] * self._num_cols
- self.run_core_tests(
- lambda: self.ds_func(record_defaults=defs, buffer_size=2),
- lambda: self.ds_func(record_defaults=defs, buffer_size=12),
- self._num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
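For reference, `ds_func` above forwards its keyword arguments directly to `readers.CsvDataset`, so the GZIP variant of the pipeline looks roughly like the sketch below. The path is illustrative (the real test writes its files under `get_temp_dir()`), and the column count mirrors `setUp`.

from tensorflow.contrib.data.python.ops import readers

record_defaults = [[0]] * 7           # seven integer columns, as in setUp
dataset = readers.CsvDataset(
    "/tmp/comp.csv",                  # a gzip-compressed CSV file
    record_defaults=record_defaults,
    compression_type="GZIP",
    buffer_size=2).repeat(14)         # self._num_epochs in the test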
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py
deleted file mode 100644
index 2139b5c33d..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the dataset constructors serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.platform import test
-
-
-class FromTensorsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_tensor_dataset(self, variable_array):
- components = (variable_array, np.array([1, 2, 3]), np.array(37.0))
-
- return dataset_ops.Dataset.from_tensors(components)
-
- def testFromTensorsCore(self):
- # Equal length components
- arr = np.array(1)
- num_outputs = 1
- diff_arr = np.array(2)
- self.run_core_tests(lambda: self._build_tensor_dataset(arr),
- lambda: self._build_tensor_dataset(diff_arr),
- num_outputs)
-
-
-class FromTensorSlicesSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_tensor_slices_dataset(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components)
-
- def testFromTensorSlicesCore(self):
- # Equal length components
- components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 22),
- np.array([37.0, 38.0, 39.0, 40.0]))
-
- diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[5], [6], [7], [8]]), 22),
- np.array([1.0, 2.0, 3.0, 4.0]))
-
- dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
-
- self.run_core_tests(lambda: self._build_tensor_slices_dataset(components),
- lambda: self._build_tensor_slices_dataset(diff_comp), 4)
- self.run_core_tests(
- lambda: self._build_tensor_slices_dataset(dict_components), None, 3)
-
-
-class FromSparseTensorSlicesSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_sparse_tensor_slice_dataset(self, slices):
- indices = np.array(
- [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))],
- dtype=np.int64)
- values = np.array([val for s in slices for val in s], dtype=np.float64)
- dense_shape = np.array(
- [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64)
- sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape)
- return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components)
-
- def testFromSparseTensorSlicesCore(self):
- slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
- diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []]
-
- self.run_core_tests(
- lambda: self._build_sparse_tensor_slice_dataset(slices),
- lambda: self._build_sparse_tensor_slice_dataset(diff_slices),
- 9,
- sparse_tensors=True)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
deleted file mode 100644
index 595cecef4d..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ /dev/null
@@ -1,692 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing serializable datasets."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import lookup_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.util import nest
-
-
-def remove_variants(get_next_op):
-  """Remove variants from a nest structure, so sess.run will execute."""
-  # TODO(b/72408568): Remove this once session.run can get
-  # variant tensors.
-
- def _remove_variant(x):
- if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
- return ()
- else:
- return x
-
- return nest.map_structure(_remove_variant, get_next_op)
-
-
-class DatasetSerializationTestBase(test.TestCase):
- """Base class for testing serializable datasets."""
-
- def tearDown(self):
- self._delete_ckpt()
-
- # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
- # (deprecated) saveable `SparseTensorSliceDataset`, once the API
-  # `from_sparse_tensor_slices()` and related tests are deleted.
- def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
- """Runs the core tests.
-
- Args:
- ds_fn1: 0-argument function that returns a Dataset.
- ds_fn2: 0-argument function that returns a Dataset different from
- ds_fn1. If None, verify_restore_in_modified_graph test is not run.
- num_outputs: Total number of outputs expected from this Dataset.
- sparse_tensors: Whether dataset is built from SparseTensor(s).
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_unused_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_fully_used_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_exhausted_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_init_before_restore(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_multiple_breaks(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_reset_restored_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_restore_in_empty_graph(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- if ds_fn2:
- self.verify_restore_in_modified_graph(
- ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
-
- def verify_unused_iterator(self,
- ds_fn,
- num_outputs,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that saving and restoring an unused iterator works.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn, [0],
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_fully_used_iterator(self, ds_fn, num_outputs,
- sparse_tensors=False):
- """Verifies that saving and restoring a fully used iterator works.
-
- Note that this only checks saving and restoring an iterator from which
- `num_outputs` items have been produced but does not check for an
- exhausted iterator, i.e., one from which an OutOfRange error has been
- returned.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if test fails.
- """
- self.verify_run_with_breaks(
- ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
-
- def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
- """Verifies that saving and restoring an exhausted iterator works.
-
- An exhausted iterator is one which has returned an OutOfRange error.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.gen_outputs(
- ds_fn, [],
- num_outputs,
- verify_exhausted=True,
- sparse_tensors=sparse_tensors)
- actual = self.gen_outputs(
- ds_fn, [],
- 0,
- ckpt_saved=True,
- verify_exhausted=True,
- sparse_tensors=sparse_tensors)
- self.assertEqual(len(actual), 0)
-
- def verify_init_before_restore(self,
- ds_fn,
- num_outputs,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that restoring into an already initialized iterator works.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs),
- num_outputs,
- init_before_restore=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_multiple_breaks(self,
- ds_fn,
- num_outputs,
- num_breaks=10,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to save/restore at multiple break points.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- num_breaks: The number of break points. These are uniformly spread in
- [0, num_outputs] both inclusive.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs, num_breaks),
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_reset_restored_iterator(self,
- ds_fn,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to re-initialize a restored iterator.
-
- This is useful when restoring a training checkpoint during validation.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Collect ground truth containing all outputs.
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Skip some items and save checkpoint.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Restore from checkpoint and then run init_op.
- with ops.Graph().as_default() as g:
- saver = self._import_meta_graph()
- init_op, get_next_op = self._get_iterator_ops_from_collection(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- self._initialize(init_op, sess)
- for _ in range(num_outputs):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- self.match(expected, actual)
-
- def verify_restore_in_modified_graph(self,
- ds_fn1,
- ds_fn2,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to restore an iterator in a modified graph.
-
- Builds an input pipeline using ds_fn1, runs it for `break_point` steps
- and saves a checkpoint. Then builds a new graph using ds_fn2, restores
- the checkpoint from ds_fn1 and verifies that the restore is successful.
-
- Args:
- ds_fn1: See `run_core_tests`.
- ds_fn2: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Skip `break_point` items and store the remaining produced from ds_fn1
- # in `expected`.
- self.gen_outputs(
- ds_fn1, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
- expected = self.gen_outputs(
- ds_fn1, [],
- num_outputs - break_point,
- ckpt_saved=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Generate `break_point` items from ds_fn1 and save checkpoint.
- self.gen_outputs(
- ds_fn1, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Build graph for ds_fn2 but load checkpoint for ds_fn1.
- with ops.Graph().as_default() as g:
- _, get_next_op, saver = self._build_graph(
- ds_fn2, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- for _ in range(num_outputs - break_point):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- self.match(expected, actual)
-
- def verify_restore_in_empty_graph(self,
- ds_fn,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to restore an iterator in an empty graph.
-
- Builds an input pipeline using ds_fn, runs it for `break_point` steps
- and saves a checkpoint. Then builds a new empty graph, restores
- the checkpoint from ds_fn and verifies that the restore is successful.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Skip `break_point` items and store the remaining produced from ds_fn
- # in `expected`.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs - break_point,
- ckpt_saved=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Generate `break_point` items from ds_fn and save checkpoint.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Build an empty graph but load checkpoint for ds_fn.
- with ops.Graph().as_default() as g:
- get_next_op, saver = self._build_empty_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- for _ in range(num_outputs - break_point):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- self.match(expected, actual)
-
- def verify_error_on_save(self,
- ds_fn,
- num_outputs,
- error,
- break_point=None,
- sparse_tensors=False):
- """Attempts to save a non-saveable iterator.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- error: Declared error when trying to save iterator.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if any test fails.
- """
-
- break_point = num_outputs // 2 if not break_point else break_point
- with ops.Graph().as_default() as g:
- init_op, get_next_op, saver = self._build_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._initialize(init_op, sess)
- for _ in range(break_point):
- sess.run(get_next_op)
- with self.assertRaises(error):
- self._save(sess, saver)
-
- def verify_run_with_breaks(self,
- ds_fn,
- break_points,
- num_outputs,
- init_before_restore=False,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that ds_fn() produces the same outputs with and without breaks.
-
- 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
- *without* stopping at break points.
- 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
- with stopping at break points.
-
-    Deep-matches the outputs from 1 and 2.
-
- Args:
- ds_fn: See `gen_outputs`.
- break_points: See `gen_outputs`.
- num_outputs: See `gen_outputs`.
- init_before_restore: See `gen_outputs`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs,
- init_before_restore=init_before_restore,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- actual = self.gen_outputs(
- ds_fn,
- break_points,
- num_outputs,
- init_before_restore=init_before_restore,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- self.match(expected, actual)
-
- def gen_outputs(self,
- ds_fn,
- break_points,
- num_outputs,
- ckpt_saved=False,
- init_before_restore=False,
- sparse_tensors=False,
- verify_exhausted=True,
- save_checkpoint_at_end=True):
- """Generates elements from input dataset while stopping at break points.
-
- Produces `num_outputs` outputs and saves the state of the iterator in the
- Saver checkpoint.
-
- Args:
- ds_fn: 0-argument function that returns the dataset.
-      break_points: A list of integers. For each `break_point` in
-        `break_points`, we produce outputs until `break_point` items have been
-        produced and then checkpoint the state. The current graph and session
-        are destroyed, and a new graph and session are used to produce outputs
-        until the next break point or until `num_outputs` elements have been
-        produced. Each `break_point` must be <= `num_outputs`.
- num_outputs: The total number of outputs to produce from the iterator.
- ckpt_saved: Whether a checkpoint already exists. If False, we build the
- graph from ds_fn.
- init_before_restore: Whether init should be called before saver.restore.
- This is just so that we can verify that restoring an already initialized
- iterator works.
- sparse_tensors: Whether dataset is built from SparseTensor(s).
- verify_exhausted: Whether to verify that the iterator has been exhausted
- after producing `num_outputs` elements.
- save_checkpoint_at_end: Whether to save a checkpoint after producing all
-        outputs. If False, checkpoints are saved at each break point but not
-        at the end. Note that checkpoints overwrite each other, so only a
-        single checkpoint is ever available. Defaults to True.
-
- Returns:
- A list of `num_outputs` items.
- """
- outputs = []
-
- def get_ops():
- if ckpt_saved:
- saver = self._import_meta_graph()
- init_op, get_next_op = self._get_iterator_ops_from_collection(
- ds_fn, sparse_tensors=sparse_tensors)
- else:
- init_op, get_next_op, saver = self._build_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- return init_op, get_next_op, saver
-
- for i in range(len(break_points) + 1):
- with ops.Graph().as_default() as g:
- init_op, get_next_op, saver = get_ops()
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- if ckpt_saved:
- if init_before_restore:
- self._initialize(init_op, sess)
- self._restore(saver, sess)
- else:
- self._initialize(init_op, sess)
- start = break_points[i - 1] if i > 0 else 0
- end = break_points[i] if i < len(break_points) else num_outputs
- num_iters = end - start
- for _ in range(num_iters):
- outputs.append(sess.run(get_next_op))
- if i == len(break_points) and verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- if save_checkpoint_at_end or i < len(break_points):
- self._save(sess, saver)
- ckpt_saved = True
-
- return outputs
-
- def match(self, expected, actual):
- """Matches nested structures.
-
- Recursively matches shape and values of `expected` and `actual`.
-    Handles scalars, numpy arrays, and other Python sequence containers,
-    e.g. list and dict.
-
- Args:
- expected: Nested structure 1.
- actual: Nested structure 2.
-
- Raises:
- AssertionError if matching fails.
- """
- if isinstance(expected, np.ndarray):
- expected = expected.tolist()
- if isinstance(actual, np.ndarray):
- actual = actual.tolist()
- self.assertEqual(type(expected), type(actual))
-
- if nest.is_sequence(expected):
- self.assertEqual(len(expected), len(actual))
- if isinstance(expected, dict):
- for key1, key2 in zip(sorted(expected), sorted(actual)):
- self.assertEqual(key1, key2)
- self.match(expected[key1], actual[key2])
- else:
- for item1, item2 in zip(expected, actual):
- self.match(item1, item2)
- else:
- self.assertEqual(expected, actual)
-
- def does_not_match(self, expected, actual):
- with self.assertRaises(AssertionError):
- self.match(expected, actual)
-
- def gen_break_points(self, num_outputs, num_samples=10):
- """Generates `num_samples` breaks points in [0, num_outputs]."""
- return np.linspace(0, num_outputs, num_samples, dtype=int)
-
- def _build_graph(self, ds_fn, sparse_tensors=False):
- iterator = ds_fn().make_initializable_iterator()
-
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- init_op = iterator.initializer
- if sparse_tensors:
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- else:
- get_next = iterator.get_next()
- self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
- sparse_tensors)
- saver = saver_lib.Saver(allow_empty=True)
- return init_op, get_next, saver
-
- def _build_empty_graph(self, ds_fn, sparse_tensors=False):
- iterator = iterator_ops.Iterator.from_structure(
- self._get_output_types(ds_fn),
- output_shapes=self._get_output_shapes(ds_fn),
- output_classes=self._get_output_classes(ds_fn))
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- if sparse_tensors:
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- else:
- get_next = iterator.get_next()
- saver = saver_lib.Saver(allow_empty=True)
- return get_next, saver
-
- def _add_iterator_ops_to_collection(self,
- init_op,
- get_next,
- ds_fn,
- sparse_tensors=False):
- ops.add_to_collection("iterator_ops", init_op)
-    # `get_next` may be a tuple, e.g. in TensorSliceDataset. Since collections
-    # do not support tuples, we flatten the tensors and restore the shape in
- # `_get_iterator_ops_from_collection`.
- if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
- ops.add_to_collection("iterator_ops", get_next.indices)
- ops.add_to_collection("iterator_ops", get_next.values)
- ops.add_to_collection("iterator_ops", get_next.dense_shape)
- return
-
- get_next_list = nest.flatten(get_next)
- for i, output_class in enumerate(
- nest.flatten(self._get_output_classes(ds_fn))):
- if output_class is sparse_tensor.SparseTensor:
- ops.add_to_collection("iterator_ops", get_next_list[i].indices)
- ops.add_to_collection("iterator_ops", get_next_list[i].values)
- ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
- else:
- ops.add_to_collection("iterator_ops", get_next_list[i])
-
- def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
- all_ops = ops.get_collection("iterator_ops")
- if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
- init_op, indices, values, dense_shape = all_ops
- return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
- get_next_list = []
- i = 1
- for output_class in nest.flatten(self._get_output_classes(ds_fn)):
- if output_class is sparse_tensor.SparseTensor:
- indices, values, dense_shape = all_ops[i:i + 3]
- i += 3
- get_next_list.append(
- sparse_tensor.SparseTensor(indices, values, dense_shape))
- else:
- get_next_list.append(all_ops[i])
- i += 1
- return all_ops[0], nest.pack_sequence_as(
- self._get_output_types(ds_fn), get_next_list)
-
- def _get_output_types(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_types
-
- def _get_output_shapes(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_shapes
-
- def _get_output_classes(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_classes
-
- def _ckpt_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _latest_ckpt(self):
- return checkpoint_management.latest_checkpoint(self.get_temp_dir())
-
- def _save(self, sess, saver):
- saver.save(sess, self._ckpt_path())
-
- def _restore(self, saver, sess):
- sess.run(lookup_ops.tables_initializer())
- saver.restore(sess, self._latest_ckpt())
-
- def _initialize(self, init_op, sess):
- sess.run(variables.global_variables_initializer())
- sess.run(lookup_ops.tables_initializer())
- sess.run(init_op)
-
- def _import_meta_graph(self):
- meta_file_path = self._ckpt_path() + ".meta"
- return saver_lib.import_meta_graph(meta_file_path)
-
- def _delete_ckpt(self):
- # Remove all checkpoint files.
- prefix = self._ckpt_path()
- pattern = prefix + "*"
-    for filename in gfile.Glob(pattern):
-      gfile.Remove(filename)  # Explicit loop: `map()` is lazy under Python 3.
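To make the base class above concrete: a new serialization test subclasses `DatasetSerializationTestBase`, supplies one or two 0-argument dataset builders, and passes the expected element count to `run_core_tests`. A minimal sketch (the dataset and counts below are illustrative, not taken from any file in this change):

from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


class RangeDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_ds(self, stop):
    return dataset_ops.Dataset.range(stop)

  def testRangeCore(self):
    # `ds_fn2` builds a structurally different pipeline so that
    # verify_restore_in_modified_graph has something to restore into;
    # `num_outputs` must match what ds_fn1 produces (here, 10 elements).
    self.run_core_tests(lambda: self._build_ds(10),
                        lambda: self._build_ds(20), 10)


if __name__ == "__main__":
  test.main()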
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py
deleted file mode 100644
index 7c170078a1..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the FilterDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class FilterDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_filter_range_graph(self, div):
- return dataset_ops.Dataset.range(100).filter(
- lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))
-
- def testFilterCore(self):
- div = 3
- num_outputs = np.sum([x % 3 != 2 for x in range(100)])
- self.run_core_tests(lambda: self._build_filter_range_graph(div),
- lambda: self._build_filter_range_graph(div * 2),
- num_outputs)
-
- def _build_filter_dict_graph(self):
- return dataset_ops.Dataset.range(10).map(
- lambda x: {"foo": x * 2, "bar": x ** 2}).filter(
- lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
- lambda d: d["foo"] + d["bar"])
-
- def testFilterDictCore(self):
- num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)])
- self.run_core_tests(self._build_filter_dict_graph, None, num_outputs)
-
- def _build_sparse_filter(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
-
- def _filter_fn(_, i):
- return math_ops.equal(i % 2, 0)
-
- return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
- lambda x, i: x)
-
- def testSparseCore(self):
- num_outputs = 5
- self.run_core_tests(self._build_sparse_filter, None, num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
deleted file mode 100644
index 34392d88d4..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the FixedLengthRecordDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.platform import test
-
-
-class FixedLengthRecordDatasetSerializationTest(
- reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self, num_epochs, compression_type=None):
- filenames = self._createFiles()
- return core_readers.FixedLengthRecordDataset(
- filenames, self._record_bytes, self._header_bytes,
- self._footer_bytes).repeat(num_epochs)
-
- def testFixedLengthRecordCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
- lambda: self._build_iterator_graph(num_epochs * 2),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py
deleted file mode 100644
index 16051ffd3f..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the FlatMapDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class FlatMapDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testCore(self):
- # Complicated way of saying range(start, start+25).
- def build_ds(start):
-
- def map_fn(x):
- return dataset_ops.Dataset.range(x, x + 5)
-
- return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn)
-
- self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25)
-
- def testMapThenFlatMap(self):
-
- def build_ds():
-
- def flat_map_fn(_):
-
- def map_fn(y):
- return 10 * math_ops.to_int32(y)
-
- return dataset_ops.Dataset.range(100).map(map_fn)
-
- return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
-
- self.run_core_tests(build_ds, None, 500)
-
- def testCaptureDefunInMapFn(self):
-
- def build_ds():
-
- def map_fn(x):
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)])
-
- return dataset_ops.Dataset.range(100).flat_map(map_fn)
-
- self.run_core_tests(build_ds, None, 100)
-
- def testDisallowVariableCapture(self):
-
- def build_ds():
- test_var = variable_scope.get_variable(
- name="test_var", shape=(), use_resource=True)
- return dataset_ops.Dataset.range(5).flat_map(
- lambda _: dataset_ops.Dataset.from_tensor_slices([test_var]))
-
- self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError)
-
- def testDisallowCapturingStatefulOps(self):
-
- def build_ds():
-
- def flat_map_fn(_):
-
- def map_fn(x):
- return random_ops.random_uniform(
- (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(100).map(map_fn)
-
- return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
-
- self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError)
-
- def testSparseCore(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _flat_map_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- def _build_ds():
- return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
-
- self.run_core_tests(_build_ds, None, 20)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py
deleted file mode 100644
index 571e0899bb..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the GroupByReducer serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class GroupByReducerSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, components):
- reducer = grouping.Reducer(
- init_func=lambda _: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- return dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_reducer(lambda x: x % 5, reducer))
-
- def testCoreGroupByReducer(self):
- components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
- self.verify_unused_iterator(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_init_before_restore(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_multiple_breaks(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_reset_restored_iterator(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_restore_in_empty_graph(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
- self.verify_restore_in_modified_graph(
- lambda: self._build_dataset(components),
- lambda: self._build_dataset(diff_components),
- 5,
- verify_exhausted=True)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py
deleted file mode 100644
index f86af4084e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the GroupByWindow serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class GroupByWindowSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
- grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
-
- def testCoreGroupByWindow(self):
- components = np.array(
- [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
- self.verify_unused_iterator(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_init_before_restore(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_multiple_breaks(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_reset_restored_iterator(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_restore_in_empty_graph(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
- self.verify_restore_in_modified_graph(
- lambda: self._build_dataset(components),
- lambda: self._build_dataset(diff_components),
- 12,
- verify_exhausted=False)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py
deleted file mode 100644
index 65ae9923b8..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the IgnoreErrors input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class IgnoreErrorsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_ds(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.check_numerics(x, "message")).apply(
- error_ops.ignore_errors())
-
- def testIgnoreErrorsCore(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
- diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32)
- num_outputs = 4
- self.run_core_tests(lambda: self._build_ds(components),
- lambda: self._build_ds(diff_components), num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
deleted file mode 100644
index 243f6405a1..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the InterleaveDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase,
- parameterized.TestCase):
-
- def _build_iterator_graph(self, input_values, cycle_length, block_length,
- num_parallel_calls):
- repeat_count = 2
- return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
- repeat_count).interleave(
- lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length, num_parallel_calls)
-
- @parameterized.named_parameters(
- ("1", 2, 3, None),
- ("2", 2, 3, 1),
- ("3", 2, 3, 2),
- ("4", 1, 3, None),
- ("5", 1, 3, 1),
- ("6", 2, 1, None),
- ("7", 2, 1, 1),
- ("8", 2, 1, 2),
- )
- def testSerializationCore(self, cycle_length, block_length,
- num_parallel_calls):
- input_values = np.array([4, 5, 6], dtype=np.int64)
- num_outputs = np.sum(input_values) * 2
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length, num_parallel_calls),
- lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length, num_parallel_calls),
- num_outputs)
- # pylint: enable=g-long-lambda
-
- def testSparseCore(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _interleave_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- def _build_dataset():
- return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
- _interleave_fn, cycle_length=1)
-
- self.run_core_tests(_build_dataset, None, 20)
-
-
-if __name__ == "__main__":
- test.main()
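The `num_outputs` arithmetic above: each input value x becomes a sub-dataset of x copies of x, and the input is repeated twice, so the pipeline yields 2 * (4 + 5 + 6) = 30 elements regardless of `cycle_length`, `block_length`, or `num_parallel_calls`. A sketch with one arbitrary parameter combination:

import numpy as np
from tensorflow.python.data.ops import dataset_ops

input_values = np.array([4, 5, 6], dtype=np.int64)
ds = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(2).interleave(
    lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
    cycle_length=2, block_length=3)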
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
deleted file mode 100644
index c9cd211328..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapAndBatchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class MapAndBatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testNumParallelBatches(self):
- range_size = 11
- num_repeats = 2
- batch_size = 5
- total_outputs = range_size * num_repeats
- num_outputs_drop_remainder = total_outputs // batch_size
- num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
- num_parallel_batches = 2
-
- def build_ds(range_start, drop_remainder=False):
-
- def _map_fn(x):
- return math_ops.square(x)
-
- return dataset_ops.Dataset.range(
- range_start, range_start + range_size).repeat(num_repeats).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_batches=num_parallel_batches,
- drop_remainder=drop_remainder))
-
- self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
- num_outputs_keep_remainder)
- self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
- num_outputs_drop_remainder)
-
- def testNumParallelCalls(self):
- range_size = 11
- num_repeats = 2
- batch_size = 5
- total_outputs = range_size * num_repeats
- num_outputs_drop_remainder = total_outputs // batch_size
- num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
- num_parallel_calls = 7
-
- def build_ds(range_start, drop_remainder=False):
-
- def _map_fn(x):
- return math_ops.square(x)
-
- return dataset_ops.Dataset.range(
- range_start, range_start + range_size).repeat(num_repeats).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_calls=num_parallel_calls,
- drop_remainder=drop_remainder))
-
- self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
- num_outputs_keep_remainder)
- self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
- num_outputs_drop_remainder)
-
-
-if __name__ == "__main__":
- test.main()
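Worked numbers for the two tests above: `total_outputs = 11 * 2 = 22`, so with `batch_size = 5` the pipeline produces `22 // 5 = 4` batches when `drop_remainder=True` and `ceil(22 / 5) = 5` batches otherwise; these are the `num_outputs_drop_remainder` and `num_outputs_keep_remainder` values passed to `run_core_tests`.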
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py
deleted file mode 100644
index ab783e5cce..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class MapDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self._tensor_slice_len = 7
- self._num_epochs = 14
- self._num_outputs = self._tensor_slice_len * self._num_epochs
-
- def _build_ds(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (
- dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(self._num_epochs))
-
- def testSaveRestoreCore(self):
- self.run_core_tests(
- self._build_ds,
- lambda: self._build_ds(multiplier=15.0),
- self._num_outputs)
-
- def testSaveStatefulFunction(self):
-
- def _build_ds():
-
- def _map_fn(x):
- return random_ops.random_uniform(
- (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(100).map(_map_fn)
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureVariableInMapFn(self):
-
- def _build_ds():
- counter_var = variable_scope.get_variable(
- "counter", (), dtypes.int32, use_resource=True)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda _: counter_var.assign_add(1)))
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureConstantInMapFn(self):
-
- def _build_ds():
- constant_var = constant_op.constant(5)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda x: x + constant_var))
-
- self.run_core_tests(_build_ds, None, 10)
-
- def testCaptureDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testBuildDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
-
- @function.Defun(dtypes.int32)
- def defun_fn_deep(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
-
- return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testSparseCore(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1])),
- dense_shape=np.array([1, 1]))
-
- def _build_ds(num_outputs):
- return dataset_ops.Dataset.range(num_outputs).map(_sparse)
-
- num_outputs = 10
- self.run_core_tests(lambda: _build_ds(num_outputs),
- lambda: _build_ds(int(num_outputs / 2)), num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py
deleted file mode 100644
index d5c03495e3..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the OptimizeDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class OptimizeDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testCore(self):
-
- def build_dataset(num_elements, batch_size):
- return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch(
- batch_size).apply(optimization.optimize(["map_and_batch_fusion"]))
-
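-    # 200 elements batched by 10 yield 20 output batches.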
- self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
deleted file mode 100644
index 9ac42a461a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the PaddedBatchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-
-
-class PaddedBatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testPaddedBatch(self):
-
- def build_dataset(seq_lens):
- return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
- lambda x: array_ops.fill([x], x)).padded_batch(
- 4, padded_shapes=[-1])
-
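-    # 32 variable-length sequences batched by 4 yield 8 padded batches.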
- seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
- seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
- self.run_core_tests(lambda: build_dataset(seq_lens1),
- lambda: build_dataset(seq_lens2), 8)
-
- def testPaddedBatchNonDefaultPadding(self):
-
- def build_dataset(seq_lens):
-
- def fill_tuple(x):
- filled = array_ops.fill([x], x)
- return (filled, string_ops.as_string(filled))
-
- padded_shape = [-1]
- return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
- fill_tuple).padded_batch(
- 4,
- padded_shapes=(padded_shape, padded_shape),
- padding_values=(-1, "<end>"))
-
- seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
- seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
- self.run_core_tests(lambda: build_dataset(seq_lens1),
- lambda: build_dataset(seq_lens2), 8)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
deleted file mode 100644
index 1f8a584df9..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ParallelInterleaveDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class ParallelInterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self.input_values = np.array([4, 5, 6], dtype=np.int64)
- self.num_repeats = 2
-    self.num_outputs = np.sum(self.input_values) * self.num_repeats
-
- def _build_ds(self, cycle_length, block_length, sloppy=False):
- return (dataset_ops.Dataset.from_tensor_slices(
- self.input_values).repeat(self.num_repeats).apply(
- interleave_ops.parallel_interleave(
- lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
- cycle_length, block_length, sloppy)))
-
- def testSerializationCore(self):
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
- self.run_core_tests(
- lambda: self._build_ds(cycle_length, block_length),
- lambda: self._build_ds(cycle_length * 2, block_length * 1),
- self.num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
- None, self.num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
- None, self.num_outputs)
-
- def testSerializationWithSloppy(self):
- break_points = self.gen_break_points(self.num_outputs, 10)
- expected_outputs = np.repeat(
- np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]),
- self.num_repeats).tolist()
-
- def run_test(cycle_length, block_length):
- actual = self.gen_outputs(
- lambda: self._build_ds(cycle_length, block_length, True),
- break_points, self.num_outputs)
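-      # sloppy=True makes the output order nondeterministic, so compare the
-      # sorted outputs rather than the exact sequence.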
- self.assertSequenceEqual(sorted(actual), expected_outputs)
-
- # cycle_length > 1, block_length > 1
- run_test(2, 3)
- # cycle_length = 1
- run_test(1, 3)
- # block_length = 1
- run_test(2, 1)
-
- def testSparseCore(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _interleave_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- def _build_dataset():
- return dataset_ops.Dataset.range(10).map(_map_fn).apply(
- interleave_ops.parallel_interleave(_interleave_fn, 1))
-
- self.run_core_tests(_build_dataset, None, 20)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
deleted file mode 100644
index 3fb7605be1..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ParallelMapDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class ParallelMapDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self._tensor_slice_len = 7
- self._num_epochs = 1
- self._num_outputs = self._tensor_slice_len * self._num_epochs
-
- def _build_ds(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (dataset_ops.Dataset.from_tensor_slices(components).map(
- _map_fn, num_parallel_calls=3).repeat(self._num_epochs))
-
- def _build_ds_with_prefetch(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (dataset_ops.Dataset.from_tensor_slices(components).map(
- _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5))
-
- def testSaveRestoreCore(self):
- for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
- self.run_core_tests(
- ds_fn,
- lambda: ds_fn(multiplier=15.0),
- self._num_outputs)
-
- def testSaveStatefulFunction(self):
-
- def _build_ds():
-
- def _map_fn(x):
- return random_ops.random_uniform(
- (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(100).map(
- _map_fn, num_parallel_calls=2).prefetch(2)
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureVariableInMapFn(self):
-
- def _build_ds():
- counter_var = variable_scope.get_variable(
- "counter", (), dtypes.int32, use_resource=True)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda _: counter_var.assign_add(1),
- num_parallel_calls=2).prefetch(2))
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureConstantInMapFn(self):
-
- def _build_ds():
- constant_var = constant_op.constant(5)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda x: x + constant_var, num_parallel_calls=2).prefetch(2))
-
- self.run_core_tests(_build_ds, None, 10)
-
- def testCaptureDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(num_outputs).map(
- defun_fn, num_parallel_calls=2).prefetch(2)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testBuildDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
-
- @function.Defun(dtypes.int32)
- def defun_fn_deep(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
-
- return dataset_ops.Dataset.range(num_outputs).map(
- defun_fn, num_parallel_calls=2).prefetch(2)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
deleted file mode 100644
index d3fa84e74c..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ParseExampleDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.platform import test
-
-
-class ParseExampleDatasetSerializationTest(
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def ParseExampleDataset(self, num_repeat, batch_size):
- return self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_repeat,
- batch_size=batch_size,
- reader_num_threads=5,
- parser_num_threads=10)
-
- def testSerializationCore(self):
- num_repeat = 5
- batch_size = 2
- num_outputs = self._num_records * self._num_files * num_repeat // batch_size
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: self.ParseExampleDataset(
- num_repeat=num_repeat, batch_size=batch_size),
- lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py
deleted file mode 100644
index c802402461..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the PrefetchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class PrefetchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def build_dataset(self, seed):
- return dataset_ops.Dataset.range(100).prefetch(10).shuffle(
- buffer_size=10, seed=seed, reshuffle_each_iteration=False)
-
- def testCore(self):
- num_outputs = 100
- self.run_core_tests(lambda: self.build_dataset(10),
- lambda: self.build_dataset(20), num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
deleted file mode 100644
index 6341190847..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the RangeDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-
-class RangeDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _iterator_checkpoint_prefix_local(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
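-  # These helpers write and read the serialized iterator state to a file
-  # directly, instead of going through a tf.train.Saver.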
- def _save_op(self, iterator_resource):
- iterator_state_variant = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- save_op = io_ops.write_file(
- self._iterator_checkpoint_prefix_local(),
- parsing_ops.serialize_tensor(iterator_state_variant))
- return save_op
-
- def _restore_op(self, iterator_resource):
- iterator_state_variant = parsing_ops.parse_tensor(
- io_ops.read_file(self._iterator_checkpoint_prefix_local()),
- dtypes.variant)
- restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
- iterator_state_variant)
- return restore_op
-
- def testSaveRestore(self):
-
- def _build_graph(start, stop):
- iterator = dataset_ops.Dataset.range(start,
- stop).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- # Saving and restoring in different sessions.
- start = 2
- stop = 10
- break_point = 5
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next, _, restore_op = _build_graph(start, stop)
- with self.session(graph=g) as sess:
- sess.run(init_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Saving and restoring in same session.
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def _build_range_dataset(self, start, stop):
- return dataset_ops.Dataset.range(start, stop)
-
- def testRangeCore(self):
- start = 2
- stop = 10
- stop_1 = 8
- self.run_core_tests(lambda: self._build_range_dataset(start, stop),
- lambda: self._build_range_dataset(start, stop_1),
- stop - start)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py
deleted file mode 100644
index fdb35ea624..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the SampleFromDatasets serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class SampleFromDatasetsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, probs, num_samples):
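-    # Each component dataset repeats its index indefinitely; a fixed seed
-    # keeps the sampling deterministic across save/restore runs.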
- dataset = interleave_ops.sample_from_datasets(
- [
- dataset_ops.Dataset.from_tensors(i).repeat(None)
- for i in range(len(probs))
- ],
- probs,
- seed=1813)
- return dataset.take(num_samples)
-
- def testSerializationCore(self):
- self.run_core_tests(
- lambda: self._build_dataset([0.5, 0.5], 100),
- lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py
deleted file mode 100644
index af9ef48c0f..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ScanDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import scan_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class ScanDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, num_elements):
- return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
- scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
-
- def testScanCore(self):
- num_output = 5
- self.run_core_tests(lambda: self._build_dataset(num_output),
- lambda: self._build_dataset(2), num_output)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py
deleted file mode 100644
index 2afebca0f5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the sequence datasets serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class SkipDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_skip_dataset(self, count):
- components = (np.arange(10),)
- return dataset_ops.Dataset.from_tensor_slices(components).skip(count)
-
- def testSkipFewerThanInputs(self):
- count = 4
- num_outputs = 10 - count
- self.run_core_tests(lambda: self._build_skip_dataset(count),
- lambda: self._build_skip_dataset(count + 2),
- num_outputs)
-
- def testSkipVarious(self):
- # Skip more than inputs
- self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0)
-    # Skip exactly the input size
-    self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0)
-    # A count of -1 skips the entire dataset
-    self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0)
- # Skip nothing
- self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)
-
- def testInvalidSkip(self):
- with self.assertRaisesRegexp(ValueError,
- 'Shape must be rank 0 but is rank 1'):
- self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)
-
-
-class TakeDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_take_dataset(self, count):
- components = (np.arange(10),)
- return dataset_ops.Dataset.from_tensor_slices(components).take(count)
-
- def testTakeFewerThanInputs(self):
- count = 4
- self.run_core_tests(
- lambda: self._build_take_dataset(count),
- lambda: self._build_take_dataset(count + 2),
- count,
- )
-
- def testTakeVarious(self):
- # Take more than inputs
- self.run_core_tests(lambda: self._build_take_dataset(20), None, 10)
- # Take exactly the input size
- self.run_core_tests(lambda: self._build_take_dataset(10), None, 10)
- # Take all
- self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10)
- # Take nothing
- self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)
-
- def testInvalidTake(self):
- with self.assertRaisesRegexp(ValueError,
- 'Shape must be rank 0 but is rank 1'):
- self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)
-
-
-class RepeatDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_repeat_dataset(self, count, take_count=3):
- components = (np.arange(10),)
- return dataset_ops.Dataset.from_tensor_slices(components).take(
- take_count).repeat(count)
-
- def testFiniteRepeat(self):
- count = 10
- self.run_core_tests(lambda: self._build_repeat_dataset(count),
- lambda: self._build_repeat_dataset(count + 2),
- 3 * count)
-
- def testEmptyRepeat(self):
- self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0)
-
- def testInfiniteRepeat(self):
- self.verify_unused_iterator(
- lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
- self.verify_init_before_restore(
- lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
- self.verify_multiple_breaks(
- lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
- self.verify_reset_restored_iterator(
- lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
- self.verify_restore_in_modified_graph(
- lambda: self._build_repeat_dataset(-1),
- lambda: self._build_repeat_dataset(2),
- 20,
- verify_exhausted=False)
-    # Test infinitely repeating an empty dataset
- self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0)
-
- def testInvalidRepeat(self):
- with self.assertRaisesRegexp(
- ValueError, 'Shape must be rank 0 but is rank 1'):
- self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0),
- None, 0)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
deleted file mode 100644
index 6aac50ecd9..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Integration test for dataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
-from tensorflow.python.training import saver as saver_lib
-
-
-class SerializationIntegrationTest(test.TestCase):
-
- def _build_input_pipeline(self, name, num_outputs):
- with ops.name_scope(name):
- ds = dataset_ops.Dataset.range(num_outputs).shuffle(
- 10, reshuffle_each_iteration=False).prefetch(10)
- iterator = ds.make_initializable_iterator()
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- return iterator.initializer, iterator.get_next()
-
- def _build_graph(self, num_pipelines, num_outputs):
- init_ops = []
- get_next_ops = []
- for i in range(num_pipelines):
- name = "input_pipeline_%d" % i
- init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
- init_ops.append(init_op)
- get_next_ops.append(get_next_op)
- saver = saver_lib.Saver()
- return init_ops, get_next_ops, saver
-
- def _ckpt_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def testConcurrentSaves(self):
- num_pipelines = 100
- num_outputs = 100
- break_point = 10
- all_outputs = [[] for _ in range(num_pipelines)]
- with ops.Graph().as_default() as g:
- init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
- num_outputs)
- with self.session(graph=g) as sess:
- sess.run(init_ops)
- for _ in range(break_point):
- output = sess.run(get_next_ops)
- for i in range(num_pipelines):
- all_outputs[i].append(output[i])
- saver.save(sess, self._ckpt_path())
-
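-    # Restore into a fresh graph and drain the remaining elements; across the
-    # two runs every pipeline should produce each value in [0, num_outputs)
-    # exactly once.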
- with ops.Graph().as_default() as g:
- init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
- num_outputs)
- with self.session(graph=g) as sess:
- saver.restore(sess, self._ckpt_path())
- for _ in range(num_outputs - break_point):
- output = sess.run(get_next_ops)
- for i in range(num_pipelines):
- all_outputs[i].append(output[i])
-
- for output in all_outputs:
- self.assertSequenceEqual(sorted(output), range(num_outputs))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
deleted file mode 100644
index f199ec835e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ShuffleAndRepeatDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import shuffle_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class ShuffleAndRepeatSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_ds(self, seed):
- return dataset_ops.Dataset.range(20).apply(
- shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
-
- def testCore(self):
- self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
- 100)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
deleted file mode 100644
index a59fa94d66..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ShuffleDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
-from tensorflow.python.training import saver as saver_lib
-
-
-class ShuffleDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_shuffle_dataset(
- self,
- range_limit=10,
- num_repeats=5,
- buffer_size=5,
- seed=None,
- reshuffle_each_iteration=None,
- ):
- return dataset_ops.Dataset.range(range_limit).shuffle(
- buffer_size,
- seed=seed,
- reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)
-
- def testShuffleCore(self):
-
- seed = 55
- range_limit = 5
- num_repeats = 2
- num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 5, 8, 10]
- # pylint: disable=cell-var-from-loop
- # pylint: disable=g-long-lambda
- for reshuffle_each_iteration in [True, False]:
- for buffer_size in buffer_sizes:
- self.run_core_tests(
- lambda: self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=seed,
- reshuffle_each_iteration=reshuffle_each_iteration),
- lambda: self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=10,
- reshuffle_each_iteration=reshuffle_each_iteration),
- num_outputs)
- # pylint: enable=cell-var-from-loop
- # pylint: enable=g-long-lambda
-
- def testNonDeterministicSeeding(self):
-
- range_limit = 5
- num_repeats = 2
- num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 5, 8, 10]
- for reshuffle_each_iteration in [True, False]:
- for buffer_size in buffer_sizes:
-
- def ds_fn():
- # pylint: disable=cell-var-from-loop
- return self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=None, # Iterator seeds are generated non-deterministically.
- reshuffle_each_iteration=reshuffle_each_iteration)
- # pylint: enable=cell-var-from-loop
-
- # We checkpoint the initial state of the Dataset so that we can restore
- # the seeds in the next run. Since the seeding is non-deterministic
- # the dataset gets initialized with different seeds each time.
- expected = self.gen_outputs(
- ds_fn,
- break_points=[0],
- num_outputs=num_outputs,
- ckpt_saved=False,
- verify_exhausted=False,
- save_checkpoint_at_end=False)
- actual = self.gen_outputs(
- ds_fn,
- break_points=self.gen_break_points(num_outputs),
- num_outputs=num_outputs,
- ckpt_saved=True,
- verify_exhausted=False)
- self.match(expected, actual)
-
- def testMultipleIterators(self):
- range_limit = 5
- num_repeats = 2
- num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 5, 8, 10]
-
- for reshuffle_each_iteration in [True, False]:
- for buffer_size in buffer_sizes:
-
- def ds_fn():
- # pylint: disable=cell-var-from-loop
- return self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=None, # Iterator seeds are generated non-deterministically.
- reshuffle_each_iteration=reshuffle_each_iteration)
- # pylint: enable=cell-var-from-loop
-
- with ops.Graph().as_default() as g:
- ds = ds_fn()
- iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
- get_next_ops = [it.get_next() for it in iterators]
- saveables = [
- contrib_iterator_ops.make_saveable_from_iterator(it)
- for it in iterators
- ]
- for saveable in saveables:
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- saver = saver_lib.Saver(allow_empty=True)
- with self.session(graph=g) as sess:
- self._save(sess, saver)
- expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
- self._restore(saver, sess)
- actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
- self.match(expected, actual)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py
deleted file mode 100644
index 93b26ed58a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the SqlDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class SqlDatasetSerializationTest(
- sql_dataset_op_test_base.SqlDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, num_repeats):
- data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
- driver_name = array_ops.placeholder_with_default(
- array_ops.constant("sqlite", dtypes.string), shape=[])
- query = ("SELECT first_name, last_name, motto FROM students ORDER BY "
- "first_name DESC")
- output_types = (dtypes.string, dtypes.string, dtypes.string)
- return readers.SqlDataset(driver_name, data_source_name, query,
- output_types).repeat(num_repeats)
-
- def testSQLSaveable(self):
- num_repeats = 4
- num_outputs = num_repeats * 2
- self.run_core_tests(lambda: self._build_dataset(num_repeats),
- lambda: self._build_dataset(num_repeats // 2),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
deleted file mode 100644
index a10f85263a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the StatsDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-# TODO(shivaniagrawal): Cannot checkpoint an input pipeline that uses the
-# `stats_ops.set_stats_aggregator` transformation, since we don't support
-# serializing StatsAggregator yet.
-class StatsDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset_bytes_stats(self, num_elements):
- return dataset_ops.Dataset.range(num_elements).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
- stats_ops.bytes_produced_stats("bytes_produced"))
-
- def test_bytes_produced_stats_invalid_tag_shape(self):
- with self.assertRaisesRegexp(
- ValueError, "Shape must be rank 0 but is rank 1"):
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: dataset_ops.Dataset.range(100).apply(
- stats_ops.bytes_produced_stats(["bytes_produced"])),
- None, 100)
- # pylint: enable=g-long-lambda
-
- def testBytesStatsDatasetSaveableCore(self):
- num_outputs = 100
- self.run_core_tests(
- lambda: self._build_dataset_bytes_stats(num_outputs),
- lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
-
- def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
- return dataset_ops.Dataset.range(num_elements).apply(
- stats_ops.latency_stats(tag))
-
- def _build_dataset_multiple_tags(self,
- num_elements,
- tag1="record_latency",
- tag2="record_latency_2"):
- return dataset_ops.Dataset.range(num_elements).apply(
- stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
-
- def test_latency_stats_invalid_tag_shape(self):
- with self.assertRaisesRegexp(
- ValueError, "Shape must be rank 0 but is rank 1"):
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats(["record_latency", "record_latency_2"])),
- None, 100)
- # pylint: enable=g-long-lambda
-
- def testLatencyStatsDatasetSaveableCore(self):
- num_outputs = 100
-
- self.run_core_tests(
- lambda: self._build_dataset_latency_stats(num_outputs),
- lambda: self._build_dataset_latency_stats(num_outputs // 10),
- num_outputs)
-
- self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
- None, num_outputs)
-
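-    # Using the same tag for both latency_stats transformations should also
-    # work.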
- tag1 = "record_latency"
- tag2 = "record_latency"
- self.run_core_tests(
- lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
- None, num_outputs)
-
- def _build_dataset_stats_aggregator(self):
- stats_aggregator = stats_ops.StatsAggregator()
- return dataset_ops.Dataset.range(10).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
-
- def test_set_stats_aggregator_not_support_checkpointing(self):
- with self.assertRaisesRegexp(errors.UnimplementedError,
- "does not support checkpointing"):
- self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py
deleted file mode 100644
index 2483787f44..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the TextLineDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.platform import test
-
-
-class TextLineDatasetSerializationTest(
- reader_dataset_ops_test_base.TextLineDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self, test_filenames, compression_type=None):
- return core_readers.TextLineDataset(
- test_filenames, compression_type=compression_type, buffer_size=10)
-
- def testTextLineCore(self):
- compression_types = [None, "GZIP", "ZLIB"]
- num_files = 5
- lines_per_file = 5
- num_outputs = num_files * lines_per_file
- for compression_type in compression_types:
- test_filenames = self._createFiles(
- num_files,
- lines_per_file,
- crlf=True,
- compression_type=compression_type)
- # pylint: disable=cell-var-from-loop
- self.run_core_tests(
- lambda: self._build_iterator_graph(test_filenames, compression_type),
- lambda: self._build_iterator_graph(test_filenames), num_outputs)
- # pylint: enable=cell-var-from-loop
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py
deleted file mode 100644
index 55a6257a27..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the TFRecordDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import zlib
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.platform import test
-
-
-class TFRecordDatasetSerializationTest(
- reader_dataset_ops_test_base.TFRecordDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self,
- num_epochs,
- batch_size=1,
- compression_type=None,
- buffer_size=None):
- filenames = self._createFiles()
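-    # Optionally re-encode the generated TFRecord files with ZLIB or GZIP so
-    # that the reader is exercised with each compression type.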
- if compression_type == "ZLIB":
- zlib_files = []
- for i, fn in enumerate(filenames):
- with open(fn, "rb") as f:
- cdata = zlib.compress(f.read())
- zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
- with open(zfn, "wb") as f:
- f.write(cdata)
- zlib_files.append(zfn)
- filenames = zlib_files
-
- elif compression_type == "GZIP":
- gzip_files = []
-      for i, fn in enumerate(filenames):
- with open(fn, "rb") as f:
- gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
- with gzip.GzipFile(gzfn, "wb") as gzf:
- gzf.write(f.read())
- gzip_files.append(gzfn)
- filenames = gzip_files
-
- return core_readers.TFRecordDataset(
- filenames, compression_type,
- buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)
-
- def testTFRecordWithoutBufferCore(self):
- num_epochs = 5
- batch_size = num_epochs
- num_outputs = num_epochs * self._num_files * self._num_records // batch_size
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, batch_size,
- buffer_size=0),
- lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
- num_outputs)
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
- num_outputs * batch_size)
- # pylint: enable=g-long-lambda
-
- def testTFRecordWithBufferCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
- lambda: self._build_iterator_graph(num_epochs * 2),
- num_outputs)
-
- def testTFRecordWithCompressionCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
- lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
- lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py
deleted file mode 100644
index b2a5a8a20d..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the UnbatchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class UnbatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
- components = (
- np.arange(tensor_slice_len),
- np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(tensor_slice_len))
-
- return dataset_ops.Dataset.from_tensor_slices(components).batch(
- batch_size).apply(batching.unbatch())
-
- def testCore(self):
- tensor_slice_len = 8
- batch_size = 2
- num_outputs = tensor_slice_len
- self.run_core_tests(
- lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
- lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
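The deleted test above drives `batching.unbatch()` from `tf.contrib.data`. Assuming the transformation is re-exported as `tf.data.experimental.unbatch()` by this migration, the dataset under test can be sketched in graph mode as:

import numpy as np
import tensorflow as tf

components = (np.arange(4), np.arange(4) * 2)
dataset = tf.data.Dataset.from_tensor_slices(components).batch(2)
dataset = dataset.apply(tf.data.experimental.unbatch())  # undo the batching: 4 (x, y) pairs again
get_next = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(get_next))  # (0, 0)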
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py
deleted file mode 100644
index 22f15b8846..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the UniqueDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import unique
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class UniqueDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testUnique(self):
-
- def build_dataset(num_elements, unique_elem_range):
- return dataset_ops.Dataset.range(num_elements).map(
- lambda x: x % unique_elem_range).apply(unique.unique())
-
- self.run_core_tests(lambda: build_dataset(200, 100),
- lambda: build_dataset(40, 100), 100)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py
deleted file mode 100644
index 340a6ff72e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ZipDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class ZipDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, arr):
- components = [
- np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 22),
- np.array(arr)
- ]
- datasets = [
- dataset_ops.Dataset.from_tensor_slices(component)
- for component in components
- ]
- return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
-
- def testCore(self):
- # Equal length components
- arr = [37.0, 38.0, 39.0, 40.0]
- num_outputs = len(arr)
- self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs)
- # Variable length components
- diff_size_arr = [1.0, 2.0]
- self.run_core_tests(lambda: self._build_dataset(diff_size_arr),
- lambda: self._build_dataset(arr), 2)
-
-
-if __name__ == "__main__":
- test.main()
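`Dataset.zip()` is already part of core `tf.data`, so the nested structure built by `_build_dataset` above is unchanged by this move; a compact graph-mode sketch:

import numpy as np
import tensorflow as tf

arr = [37.0, 38.0, 39.0, 40.0]
a = tf.data.Dataset.from_tensor_slices(np.tile(np.array([[1], [2], [3], [4]]), 20))
b = tf.data.Dataset.from_tensor_slices(np.tile(np.array([[12], [13], [14], [15]]), 22))
c = tf.data.Dataset.from_tensor_slices(np.array(arr))
dataset = tf.data.Dataset.zip((a, (b, c)))  # yields (vector_of_20, (vector_of_22, scalar)) tuples
get_next = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(get_next))  # first element: twenty 1s, (twenty-two 12s, 37.0)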
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
deleted file mode 100644
index c97002a255..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental shuffle_and_repeat transformation."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import shuffle_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
-
-
-class ShuffleAndRepeatTest(test_base.DatasetTestBase):
-
- def _build_ds(self, seed, count=5, num_elements=20):
- return dataset_ops.Dataset.range(num_elements).apply(
- shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed))
-
- def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
- get_next = ds_fn().make_one_shot_iterator().get_next()
- outputs = []
- with self.cached_session() as sess:
- for _ in range(num_outputs):
- outputs.append(sess.run(get_next))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- return outputs
-
- def testCorrectOutput(self):
- output = self._gen_outputs(lambda: self._build_ds(10), 100)
- self.assertSequenceEqual(
- sorted(output), sorted(
- np.array([range(20) for _ in range(5)]).flatten()))
- for i in range(5):
- self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
-
- def testReshuffling(self):
- # Check that the output orders of different epochs are indeed different.
- output = self._gen_outputs(lambda: self._build_ds(10), 100)
- for i in range(4):
- epoch1 = output[i * 20:(i + 1) * 20]
- epoch2 = output[(i + 1) * 20:(i + 2) * 20]
- self.assertNotEqual(epoch1, epoch2)
-
- def testSameOrderForSameSeeds(self):
- output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
- output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
- self.assertEqual(output1, output2)
-
- def testDifferentOrderForDifferentSeeds(self):
- output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
- output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
- self.assertNotEqual(output1, output2)
- self.assertEqual(sorted(output1), sorted(output2))
-
- def testCountNone(self):
- output1 = self._gen_outputs(
- lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
- output2 = self._gen_outputs(
- lambda: self._build_ds(20, count=None), 100, verify_exhausted=False)
- self.assertNotEqual(output1, output2)
- self.assertEqual(sorted(output1), sorted(output2))
-
- def testCountMinusOne(self):
- output1 = self._gen_outputs(
- lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
- output2 = self._gen_outputs(
- lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False)
- self.assertNotEqual(output1, output2)
- self.assertEqual(sorted(output1), sorted(output2))
-
- def testInfiniteOutputs(self):
- # Asserting the iterator is exhausted after producing 100 items should fail.
- with self.assertRaises(AssertionError):
- self._gen_outputs(lambda: self._build_ds(10, count=None), 100)
- with self.assertRaises(AssertionError):
- self._gen_outputs(lambda: self._build_ds(10, count=-1), 100)
-
- def testInfiniteEmpty(self):
- with self.assertRaises(errors.OutOfRangeError):
- self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
- 100)
- with self.assertRaises(errors.OutOfRangeError):
- self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0),
- 100)
-
- def testLargeBufferSize(self):
- with ops.Graph().as_default() as g:
- ds = dataset_ops.Dataset.range(20).apply(
- shuffle_ops.shuffle_and_repeat(buffer_size=21))
- get_next_op = ds.make_one_shot_iterator().get_next()
- with self.session(graph=g) as sess:
- sess.run(get_next_op)
-
-
-if __name__ == "__main__":
- test.main()
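Assuming `shuffle_and_repeat` is re-exported as `tf.data.experimental.shuffle_and_repeat` by this migration, the `_build_ds` helper above corresponds roughly to the following graph-mode sketch:

import tensorflow as tf

# Shuffle within a buffer of 5 elements and repeat for 5 epochs, seeded for determinism.
dataset = tf.data.Dataset.range(20).apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=5, count=5, seed=10))
get_next = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  first = sess.run(get_next)  # some value in [0, 20); the order depends on the seed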
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
deleted file mode 100644
index 52823d3fca..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ /dev/null
@@ -1,590 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for experimental sql input op."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
-
- # Test that SqlDataset can read from a database table.
- def testReadResultSet(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string), 2)
- with self.cached_session() as sess:
- for _ in range(2): # Run twice to verify statelessness of db operations.
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name, motto FROM students "
- "ORDER BY first_name DESC"
- })
- for _ in range(2): # Dataset is repeated. See setUp.
- self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
- self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that SqlDataset works on a join query.
- def testReadResultSetJoinQuery(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT students.first_name, state, motto FROM students "
- "INNER JOIN people "
- "ON students.first_name = people.first_name "
- "AND students.last_name = people.last_name"
- })
- self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that SqlDataset can read a database entry with a null-terminator
- # in the middle of the text and place the entry in a `string` tensor.
- def testReadResultSetNullTerminator(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, favorite_nonsense_word "
- "FROM students ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
- self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that SqlDataset works when used on two different queries.
- # Because the output types of the dataset must be determined at graph-creation
- # time, the two queries must have the same number and types of columns.
- def testReadResultSetReuseSqlDataset(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name, motto FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
- self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name, state FROM people "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
- self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that an `OutOfRangeError` is raised on the first call to
- # `get_next_str_only` if result set is empty.
- def testReadEmptyResultSet(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name, motto FROM students "
- "WHERE first_name = 'Nonexistent'"
- })
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that an error is raised when `driver_name` is invalid.
- def testReadResultSetWithInvalidDriverName(self):
- init_op = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))[0]
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(
- init_op,
- feed_dict={
- self.driver_name: "sqlfake",
- self.query: "SELECT first_name, last_name, motto FROM students "
- "ORDER BY first_name DESC"
- })
-
- # Test that an error is raised when a column name in `query` is nonexistent
- def testReadResultSetWithInvalidColumnName(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, fake_column FROM students "
- "ORDER BY first_name DESC"
- })
- with self.assertRaises(errors.UnknownError):
- sess.run(get_next)
-
- # Test that an error is raised when there is a syntax error in `query`.
- def testReadResultSetOfQueryWithSyntaxError(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELEmispellECT first_name, last_name, motto FROM students "
- "ORDER BY first_name DESC"
- })
- with self.assertRaises(errors.UnknownError):
- sess.run(get_next)
-
- # Test that an error is raised when the number of columns in `query`
- # does not match the length of `output_types`.
- def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name FROM students "
- "ORDER BY first_name DESC"
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- # Test that no results are returned when `query` is an insert query rather
- # than a select query. In particular, the error refers to the number of
- # output types passed to the op not matching the number of columns in the
- # result set of the query (namely, 0 for an insert statement.)
- def testReadResultSetOfInsertQuery(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "INSERT INTO students (first_name, last_name, motto) "
- "VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table and
- # place it in an `int8` tensor.
- def testReadResultSetInt8(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a negative or 0-valued integer from a
- # SQLite database table and place it in an `int8` tensor.
- def testReadResultSetInt8NegativeAndZero(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
- dtypes.int8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, income, favorite_negative_number "
- "FROM students "
- "WHERE first_name = 'John' ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 0, -2), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a large (positive or negative) integer from
- # a SQLite database table and place it in an `int8` tensor.
- def testReadResultSetInt8MaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT desk_number, favorite_negative_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((9, -2), sess.run(get_next))
- # Max and min values of int8
- self.assertEqual((127, -128), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table and
- # place it in an `int16` tensor.
- def testReadResultSetInt16(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a negative or 0-valued integer from a
- # SQLite database table and place it in an `int16` tensor.
- def testReadResultSetInt16NegativeAndZero(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
- dtypes.int16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, income, favorite_negative_number "
- "FROM students "
- "WHERE first_name = 'John' ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 0, -2), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a large (positive or negative) integer from
- # a SQLite database table and place it in an `int16` tensor.
- def testReadResultSetInt16MaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, favorite_medium_sized_number "
- "FROM students ORDER BY first_name DESC"
- })
- # Max value of int16
- self.assertEqual((b"John", 32767), sess.run(get_next))
- # Min value of int16
- self.assertEqual((b"Jane", -32768), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table and
- # place it in an `int32` tensor.
- def testReadResultSetInt32(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
-
- # Test that `SqlDataset` can read a negative or 0-valued integer from a
- # SQLite database table and place it in an `int32` tensor.
- def testReadResultSetInt32NegativeAndZero(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, income FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 0), sess.run(get_next))
- self.assertEqual((b"Jane", -20000), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a large (positive or negative) integer from
- # a SQLite database table and place it in an `int32` tensor.
- def testReadResultSetInt32MaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, favorite_number FROM students "
- "ORDER BY first_name DESC"
- })
- # Max value of int32
- self.assertEqual((b"John", 2147483647), sess.run(get_next))
- # Min value of int32
- self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
- # table and place it in an `int32` tensor.
- def testReadResultSetInt32VarCharColumnAsInt(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, school_id FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 123), sess.run(get_next))
- self.assertEqual((b"Jane", 1000), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table
- # and place it in an `int64` tensor.
- def testReadResultSetInt64(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a negative or 0-valued integer from a
- # SQLite database table and place it in an `int64` tensor.
- def testReadResultSetInt64NegativeAndZero(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, income FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 0), sess.run(get_next))
- self.assertEqual((b"Jane", -20000), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a large (positive or negative) integer from
- # a SQLite database table and place it in an `int64` tensor.
- def testReadResultSetInt64MaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, favorite_big_number FROM students "
- "ORDER BY first_name DESC"
- })
- # Max value of int64
- self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
- # Min value of int64
- self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table and
- # place it in a `uint8` tensor.
- def testReadResultSetUInt8(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read the minimum and maximum uint8 values from a
- # SQLite database table and place them in `uint8` tensors.
- def testReadResultSetUInt8MinAndMaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, brownie_points FROM students "
- "ORDER BY first_name DESC"
- })
- # Min value of uint8
- self.assertEqual((b"John", 0), sess.run(get_next))
- # Max value of uint8
- self.assertEqual((b"Jane", 255), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table
- # and place it in a `uint16` tensor.
- def testReadResultSetUInt16(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read the minimum and maximum uint16 values from a
- # SQLite database table and place them in `uint16` tensors.
- def testReadResultSetUInt16MinAndMaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, account_balance FROM students "
- "ORDER BY first_name DESC"
- })
- # Min value of uint16
- self.assertEqual((b"John", 0), sess.run(get_next))
- # Max value of uint16
- self.assertEqual((b"Jane", 65535), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
- # SQLite database table and place them as `True` and `False` respectively
- # in `bool` tensors.
- def testReadResultSetBool(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, registration_complete FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", True), sess.run(get_next))
- self.assertEqual((b"Jane", False), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
- # from a SQLite database table and place it as `True` in a `bool` tensor.
- def testReadResultSetBoolNotZeroOrOne(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, favorite_medium_sized_number "
- "FROM students ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", True), sess.run(get_next))
- self.assertEqual((b"Jane", True), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a float from a SQLite database table
- # and place it in a `float64` tensor.
- def testReadResultSetFloat64(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.float64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, victories FROM townspeople "
- "ORDER BY first_name"
- })
- self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
- self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a float from a SQLite database table beyond
- # the precision of 64-bit IEEE, without throwing an error. Test that
- # `SqlDataset` identifies such a value as equal to itself.
- def testReadResultSetFloat64OverlyPrecise(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.float64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, accolades FROM townspeople "
- "ORDER BY first_name"
- })
- self.assertEqual(
- (b"George", b"Washington",
- 1331241.321342132321324589798264627463827647382647382643874),
- sess.run(get_next))
- self.assertEqual(
- (b"John", b"Adams",
- 1331241321342132321324589798264627463827647382647382643874.0),
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a float from a SQLite database table,
- # representing the largest integer representable as a 64-bit IEEE float
- # such that the previous integer is also representable as a 64-bit IEEE float.
- # Test that `SqlDataset` can distinguish these two numbers.
- def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.float64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, triumphs FROM townspeople "
- "ORDER BY first_name"
- })
- self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
- sess.run(get_next))
- self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
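Assuming `SqlDataset` is re-exported as `tf.data.experimental.SqlDataset` by this migration, a standalone version of the read pattern tested above would look like the sketch below (the SQLite path is hypothetical; the `students` schema mirrors `sql_dataset_op_test_base.py`, which follows):

import tensorflow as tf

dataset = tf.data.experimental.SqlDataset(
    "sqlite",                  # driver name
    "/tmp/tftest.sqlite",      # hypothetical database file
    "SELECT first_name, last_name, motto FROM students ORDER BY first_name DESC",
    (tf.string, tf.string, tf.string))
iterator = dataset.make_initializable_iterator()
next_row = iterator.get_next()
with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(next_row))  # e.g. (b"John", b"Doe", b"Hi!") given the fixture data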
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
deleted file mode 100644
index 319a2ea263..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing SqlDataset."""
-
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import sqlite3
-
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class SqlDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing SqlDataset."""
-
- def _createSqlDataset(self, output_types, num_repeats=1):
- dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
- self.query, output_types).repeat(num_repeats)
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- return init_op, get_next
-
- def setUp(self):
- self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
- self.driver_name = array_ops.placeholder_with_default(
- array_ops.constant("sqlite", dtypes.string), shape=[])
- self.query = array_ops.placeholder(dtypes.string, shape=[])
-
- conn = sqlite3.connect(self.data_source_name)
- c = conn.cursor()
- c.execute("DROP TABLE IF EXISTS students")
- c.execute("DROP TABLE IF EXISTS people")
- c.execute("DROP TABLE IF EXISTS townspeople")
- c.execute(
- "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
- "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
- "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
- "desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
- "favorite_big_number INTEGER, favorite_negative_number INTEGER, "
- "favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
- "account_balance INTEGER, registration_complete INTEGER)")
- c.executemany(
- "INSERT INTO students (first_name, last_name, motto, school_id, "
- "favorite_nonsense_word, desk_number, income, favorite_number, "
- "favorite_big_number, favorite_negative_number, "
- "favorite_medium_sized_number, brownie_points, account_balance, "
- "registration_complete) "
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
- [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
- 9223372036854775807, -2, 32767, 0, 0, 1),
- ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
- -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
- c.execute(
- "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
- "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
- c.executemany(
- "INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)",
- [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
- "California")])
- c.execute(
- "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
- "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
- "FLOAT, accolades FLOAT, triumphs FLOAT)")
- c.executemany(
- "INSERT INTO townspeople (first_name, last_name, victories, "
- "accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
- [("George", "Washington", 20.00,
- 1331241.321342132321324589798264627463827647382647382643874,
- 9007199254740991.0),
- ("John", "Adams", -19.95,
- 1331241321342132321324589798264627463827647382647382643874.0,
- 9007199254740992.0)])
- conn.commit()
- conn.close()
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
deleted file mode 100644
index be8ae5e955..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ /dev/null
@@ -1,253 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline statistics gathering ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
-
- def testBytesProduced(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
- stats_ops.bytes_produced_stats("bytes_produced")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- expected_sum = 0.0
- for i in range(100):
- self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
- expected_sum += i * 8.0
- self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
- self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
-
- def testLatencyStats(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(i + 1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
-
- def testPrefetchBufferUtilization(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
- -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
- float(i + 1))
- self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
- self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
- self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
- 0, 1)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
- 100)
-
- def testPrefetchBufferScalars(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(10).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
- 0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
- summary_str = sess.run(summary_t)
- self._assertSummaryHasScalarValue(summary_str,
- "Prefetch::buffer_capacity", 0)
- self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
- 0)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testFilteredElementsStats(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(101).filter(
- lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(34):
- self.assertEqual(i * 3, sess.run(next_element))
- if i != 0:
- self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
- self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::dropped_elements", 67.0)
- self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::filtered_elements", 34.0)
-
- def testReinitialize(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- for j in range(5):
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", (j + 1) * 100.0)
-
- def testNoAggregatorRegistered(self):
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency"))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMultipleTags(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.latency_stats("record_latency_2")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(i + 1))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency_2", float(i + 1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency_2", 100.0)
-
- def testRepeatedTags(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(2 * (i + 1)))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
-
- def testMultipleIteratorsSameAggregator(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator_0 = dataset.make_initializable_iterator()
- iterator_1 = dataset.make_initializable_iterator()
- next_element = iterator_0.get_next() + iterator_1.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run([iterator_0.initializer, iterator_1.initializer])
- for i in range(100):
- self.assertEqual(i * 2, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(2 * (i + 1)))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
-
-
-if __name__ == "__main__":
- test.main()
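Assuming the stats transformations are exposed under `tf.data.experimental` after this change, the pattern in `testLatencyStats` above reduces to the following graph-mode sketch:

import tensorflow as tf

aggregator = tf.data.experimental.StatsAggregator()
dataset = tf.data.Dataset.range(100).apply(
    tf.data.experimental.latency_stats("record_latency")).apply(
        tf.data.experimental.set_stats_aggregator(aggregator))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = aggregator.get_summary()
with tf.Session() as sess:
  sess.run(iterator.initializer)
  sess.run(next_element)
  summary_str = sess.run(summary_t)  # serialized Summary proto containing the latency histogram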
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
deleted file mode 100644
index 80f2625927..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing the input pipeline statistics gathering ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from tensorflow.core.framework import summary_pb2
-from tensorflow.python.data.kernel_tests import test_base
-
-
-class StatsDatasetTestBase(test_base.DatasetTestBase):
- """Base class for testing statistics gathered in `StatsAggregator`."""
-
- def _assertSummaryContains(self, summary_str, tag):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasCount(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.histo.num)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertLessEqual(min_value, value.histo.min)
- self.assertGreaterEqual(max_value, value.histo.max)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasSum(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.histo.sum)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.simple_value)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
deleted file mode 100644
index 08de3a9143..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for overriding the thread pool used by tf.data pipelines."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import threading
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import threadpool
-from tensorflow.contrib.data.python.ops import unique
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import script_ops
-from tensorflow.python.platform import test
-
-
-class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
- parameterized.TestCase):
-
- @parameterized.named_parameters(
- ("1", 1, None),
- ("2", 2, None),
- ("3", 4, None),
- ("4", 8, None),
- ("5", 16, None),
- ("6", 4, -1),
- ("7", 4, 0),
- ("8", 4, 1),
- ("9", 4, 4),
- )
- def testNumThreads(self, num_threads, max_intra_op_parallelism):
-
- def get_thread_id(_):
- # Python creates a dummy thread object to represent the current
- # thread when called from an "alien" thread (such as a
- # `PrivateThreadPool` thread in this case). It does not include
- # the TensorFlow-given display name, but it has a unique
- # identifier that maps one-to-one with the underlying OS thread.
- return np.array(threading.current_thread().ident).astype(np.int64)
-
- dataset = (
- dataset_ops.Dataset.range(1000).map(
- lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
- num_parallel_calls=32).apply(unique.unique()))
-
- dataset = threadpool.override_threadpool(
- dataset,
- threadpool.PrivateThreadPool(
- num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name="private_thread_pool_%d" % num_threads))
-
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- thread_ids = []
- try:
- while True:
- thread_ids.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
- self.assertEqual(len(thread_ids), len(set(thread_ids)))
- self.assertGreater(len(thread_ids), 0)
- # NOTE(mrry): We don't control the thread pool scheduling, and
- # so cannot guarantee that all of the threads in the pool will
- # perform work.
- self.assertLessEqual(len(thread_ids), num_threads)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
deleted file mode 100644
index 8856ce5afb..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental unique transformation."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import unique
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class UniqueDatasetTest(test_base.DatasetTestBase):
-
- def _testSimpleHelper(self, dtype, test_cases):
- """Test the `unique()` transformation on a list of test cases.
-
- Args:
- dtype: The `dtype` of the elements in each test case.
- test_cases: A list of pairs of lists. The first component is the test
- input that will be passed to the transformation; the second component
- is the expected sequence of outputs from the transformation.
- """
-
- # The `current_test_case` will be updated when we loop over `test_cases`
- # below; declare it here so that the generator can capture it once.
- current_test_case = []
- dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
- dtype).apply(unique.unique())
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_case, expected in test_cases:
- current_test_case = test_case
- sess.run(iterator.initializer)
- for element in expected:
- if dtype == dtypes.string:
- element = compat.as_bytes(element)
- self.assertAllEqual(element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testSimpleInt(self):
- for dtype in [dtypes.int32, dtypes.int64]:
- self._testSimpleHelper(dtype, [
- ([], []),
- ([1], [1]),
- ([1, 1, 1, 1, 1, 1, 1], [1]),
- ([1, 2, 3, 4], [1, 2, 3, 4]),
- ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
- ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
- ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
- ])
-
- def testSimpleString(self):
- self._testSimpleHelper(dtypes.string, [
- ([], []),
- (["hello"], ["hello"]),
- (["hello", "hello", "hello"], ["hello"]),
- (["hello", "world"], ["hello", "world"]),
- (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
- ])
-
-
-if __name__ == "__main__":
- test.main()
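Assuming `unique()` is re-exported as `tf.data.experimental.unique()` by this migration, the behaviour covered by `testSimpleInt` above boils down to:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 4, 3, 2, 1, 2, 3, 4])
dataset = dataset.apply(tf.data.experimental.unique())  # keeps first occurrences: 1, 2, 4, 3
get_next = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(get_next))  # 1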
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
deleted file mode 100644
index 79134c7bc6..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ /dev/null
@@ -1,527 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def _structuredDataset(self, structure, shape, dtype):
- if structure is None:
- return dataset_ops.Dataset.from_tensors(
- array_ops.zeros(shape, dtype=dtype))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredDataset(substructure, shape, dtype)
- for substructure in structure
- ]))
-
- def _structuredElement(self, structure, shape, dtype):
- if structure is None:
- return array_ops.zeros(shape, dtype=dtype)
- else:
- return tuple([
- self._structuredElement(substructure, shape, dtype)
- for substructure in structure
- ])
-
- def _assertEqual(self, xs, ys):
- self.assertEqual(type(xs), type(ys))
- if isinstance(xs, tuple) and isinstance(ys, tuple):
- self.assertEqual(len(xs), len(ys))
- for x, y in zip(xs, ys):
- self._assertEqual(x, y)
- elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray):
- self.assertAllEqual(xs, ys)
- else:
- self.assertEqual(xs, ys)
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetFlatMap(self, structure, shape, dtype):
- """Tests windowing by chaining it with flat map.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return args[0]
- return dataset_ops.Dataset.zip(
- tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
-
- dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).flat_map(fn)
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(self._structuredElement(structure, shape, dtype))
- for _ in range(5):
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetBatchDense(self, structure, shape, dtype):
- """Tests batching of dense tensor windows.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.batch_window(args[0])
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
- for arg in args
- ])
-
- dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredElement(structure, np.concatenate(
- ([5], shape), axis=0), dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([])),
- ("2", np.int32([1])),
- ("3", np.int32([1, 2, 3])),
- )
- def testWindowDatasetBatchDenseDynamicShape(self, shape):
- """Tests batching of dynamically shaped dense tensor windows.
-
- Args:
- shape: the input shape
- """
-
- shape_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.zeros(shape_t)).repeat(5).apply(
- grouping.window_dataset(5)).apply(
- grouping._map_x_dataset(batching.batch_window))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shape_t: shape})
- expected = sess.run(
- self._structuredElement(None, np.concatenate(([5], shape), axis=0),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- def _make_dense_to_sparse_fn(self, is_scalar):
-
- def dense_to_sparse_scalar(tensor):
- indices = [[]]
- values = array_ops.expand_dims(tensor, 0)
- shape = []
- return sparse_tensor.SparseTensorValue(indices, values, shape)
-
- def dense_to_sparse_non_scalar(tensor):
- indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool))
- values = array_ops.gather_nd(tensor, indices)
- shape = array_ops.shape(tensor, out_type=dtypes.int64)
- return sparse_tensor.SparseTensorValue(indices, values, shape)
-
- if is_scalar:
- return dense_to_sparse_scalar
- return dense_to_sparse_non_scalar
-
- def _structuredSparseDataset(self, structure, shape, dtype):
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- if structure is None:
- return dataset_ops.Dataset.from_tensors(
- dense_to_sparse(array_ops.zeros(shape, dtype=dtype)))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredSparseDataset(substructure, shape, dtype)
- for substructure in structure
- ]))
-
- def _structuredSparseElement(self, structure, shape, dtype):
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- if structure is None:
- return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
- else:
- return tuple([
- self._structuredSparseElement(substructure, shape, dtype)
- for substructure in structure
- ])
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetBatchSparse(self, structure, shape, dtype):
- """Tests batching of sparse tensor windows.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.batch_window(args[0])
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
- for arg in args
- ])
-
- dataset = self._structuredSparseDataset(
- structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredSparseElement(structure,
- np.concatenate(([5], shape), axis=0),
- dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([])),
- ("2", np.int32([1])),
- ("3", np.int32([1, 2, 3])),
- )
- def testWindowDatasetBatchSparseDynamicShape(self, shape):
- """Tests batching of dynamically shaped sparse tensor windows.
-
- Args:
- shape: the input shape
- """
-
- shape_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map(
- self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test
- grouping.window_dataset(5)).apply(
- grouping._map_x_dataset(batching.batch_window))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shape_t: shape})
- expected = sess.run(
- self._structuredSparseElement(None,
- np.concatenate(([5], shape), axis=0),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- def _structuredRaggedDataset(self, structure, shapes, dtype):
-
- if structure is None:
- return dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtype))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredRaggedDataset(substructure, shapes, dtype)
- for substructure in structure
- ]))
-
- @parameterized.named_parameters(
- ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("8", (None,
- (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
- )
- def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
- padded_shape):
- """Tests padded batching of dense tensor windows.
-
- Args:
- structure: the input structure
- shapes: the input shapes
- dtype: the input data type
- padded_shape: the shape to pad the output to
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.padded_batch_window(args[0], padded_shape)
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
- arg, padded_shape) for arg in args
- ])
-
- dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- expected = sess.run(
- self._structuredElement(
- structure,
- np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([[1], [2], [3]]), [-1]),
- ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
- )
- def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
- """Tests padded batching of dynamically shaped dense tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- shapes_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shapes_t: shapes})
- expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- expected = sess.run(
- self._structuredElement(
- None, np.concatenate((np.int32([len(shapes)]), expected_shape)),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([[1]]), np.int32([0])),
- ("2", np.int32([[10], [20]]), np.int32([15])),
- )
- def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
- """Tests invalid padded batching of dense tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def _structuredRaggedSparseDataset(self, structure, shapes, dtype):
-
- def map_fn(shape):
- dense_to_sparse = self._make_dense_to_sparse_fn(False)
- return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
-
- if structure is None:
- return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn)
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredRaggedSparseDataset(substructure, shapes, dtype)
- for substructure in structure
- ]))
-
- def _structuredRaggedSparseElement(self, structure, shapes, dtype,
- padded_shape):
- if structure is None:
- dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- values = []
- for shape in shapes:
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
- padded_sparse = sparse_tensor.SparseTensor(sparse.indices,
- sparse.values, dense_shape)
- reshaped_sparse = sparse_ops.sparse_reshape(
- padded_sparse,
- array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0))
- values.append(reshaped_sparse)
- return sparse_ops.sparse_concat(0, values)
- else:
- return tuple([
- self._structuredRaggedSparseElement(substructure, shapes, dtype,
- padded_shape)
- for substructure in structure
- ])
-
- @parameterized.named_parameters(
- ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("8", (None,
- (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
- )
- def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
- padded_shape):
- """Tests padded batching of sparse tensor windows.
-
- Args:
- structure: the input structure
- shapes: the input shapes
- dtype: the input data type
- padded_shape: the shape to pad the output to
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.padded_batch_window(args[0], padded_shape)
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
- arg, padded_shape) for arg in args
- ])
-
- dataset = self._structuredRaggedSparseDataset(
- structure, shapes, dtype).apply(grouping.window_dataset(
- len(shapes))).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredRaggedSparseElement(structure, shapes, dtype,
- padded_shape))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int64([[1], [2], [3]]), [-1]),
- ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
- )
- def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
- padded_shape):
- """Tests padded batching of dynamically shaped sparse tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- shapes_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
- self._make_dense_to_sparse_fn(False)
- ).apply(grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shapes_t: shapes})
- expected = sess.run(
- self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
- padded_shape))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int64([[1]]), [0]),
- ("2", np.int64([[10], [20]]), [15]),
- )
- def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
- """Tests invalid padded batching of sparse tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
- self._make_dense_to_sparse_fn(False)
- ).apply(grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
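
The window tests above drove the contrib-only `grouping.window_dataset()` plus `batching.batch_window()`/`padded_batch_window()` helpers; the batching helpers are removed from `batching.py` later in this diff without a forwarding wrapper. A rough equivalent of the dense cases using the core `Dataset.window()` API (an assumed migration path, not stated in this commit):

import tensorflow as tf

# Group consecutive elements into windows of 5, then batch each window into a
# single dense tensor (mirrors window_dataset(5) followed by batch_window).
ds = tf.data.Dataset.range(20)
batched = ds.window(5, drop_remainder=True).flat_map(lambda w: w.batch(5))

# For ragged elements, pad within each window (mirrors padded_batch_window).
ragged = tf.data.Dataset.range(6).map(lambda i: tf.zeros([i + 1], tf.int32))
padded = ragged.window(3, drop_remainder=True).flat_map(
    lambda w: w.padded_batch(3, padded_shapes=[None]))
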
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
deleted file mode 100644
index fca546a570..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.data.python.ops import writers
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import readers
-from tensorflow.python.framework import dtypes
-from tensorflow.python.lib.io import python_io
-from tensorflow.python.lib.io import tf_record
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class TFRecordWriterTest(test_base.DatasetTestBase):
-
- def setUp(self):
- super(TFRecordWriterTest, self).setUp()
- self._num_records = 7
- self.filename = array_ops.placeholder(dtypes.string, shape=[])
- self.compression_type = array_ops.placeholder_with_default("", shape=[])
-
- input_dataset = readers.TFRecordDataset([self.filename],
- self.compression_type)
- self.writer = writers.TFRecordWriter(
- self._outputFilename(), self.compression_type).write(input_dataset)
-
- def _record(self, i):
- return compat.as_bytes("Record %d" % (i))
-
- def _createFile(self, options=None):
- filename = self._inputFilename()
- writer = python_io.TFRecordWriter(filename, options)
- for i in range(self._num_records):
- writer.write(self._record(i))
- writer.close()
- return filename
-
- def _inputFilename(self):
- return os.path.join(self.get_temp_dir(), "tf_record.in.txt")
-
- def _outputFilename(self):
- return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
-
- def testWrite(self):
- with self.cached_session() as sess:
- sess.run(
- self.writer, feed_dict={
- self.filename: self._createFile(),
- })
- for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
- self.assertAllEqual(self._record(i), r)
-
- def testWriteZLIB(self):
- options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
- with self.cached_session() as sess:
- sess.run(
- self.writer,
- feed_dict={
- self.filename: self._createFile(options),
- self.compression_type: "ZLIB",
- })
- for i, r in enumerate(
- tf_record.tf_record_iterator(self._outputFilename(), options=options)):
- self.assertAllEqual(self._record(i), r)
-
- def testWriteGZIP(self):
- options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
- with self.cached_session() as sess:
- sess.run(
- self.writer,
- feed_dict={
- self.filename: self._createFile(options),
- self.compression_type: "GZIP",
- })
- for i, r in enumerate(
- tf_record.tf_record_iterator(self._outputFilename(), options=options)):
- self.assertAllEqual(self._record(i), r)
-
- def testFailDataset(self):
- with self.assertRaises(TypeError):
- writers.TFRecordWriter(self._outputFilename(),
- self.compression_type).write("whoops")
-
- def testFailDType(self):
- input_dataset = dataset_ops.Dataset.from_tensors(10)
- with self.assertRaises(TypeError):
- writers.TFRecordWriter(self._outputFilename(),
- self.compression_type).write(input_dataset)
-
- def testFailShape(self):
- input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
- with self.assertRaises(TypeError):
- writers.TFRecordWriter(self._outputFilename(),
- self.compression_type).write(input_dataset)
-
-
-if __name__ == "__main__":
- test.main()
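
The TFRecord writer exercised above now lives in `tf.data.experimental` (see the `:writers` dependency switch in the BUILD diff below). A minimal sketch of the post-move endpoint; the output path is hypothetical:

import tensorflow as tf

# Serialize a dataset of scalar strings to a TFRecord file.
records = tf.data.Dataset.from_tensor_slices([b"Record 0", b"Record 1", b"Record 2"])
writer = tf.data.experimental.TFRecordWriter("/tmp/tf_record.out")  # hypothetical path
write_op = writer.write(records)
# In graph mode this is an op to run in a session, as in testWrite above;
# under eager execution the file is written immediately.
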
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 5cd1ed542b..34dc2379d0 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -16,10 +16,7 @@ py_library(
srcs = ["counter.py"],
srcs_version = "PY2AND3",
deps = [
- ":scan_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:counter",
],
)
@@ -28,12 +25,7 @@ py_library(
srcs = ["get_single_element.py"],
srcs_version = "PY2AND3",
deps = [
- ":grouping",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- "//third_party/py/numpy",
+ "//tensorflow/python/data/experimental/ops:get_single_element",
],
)
@@ -44,10 +36,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
],
)
@@ -58,15 +47,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:random_ops",
],
)
@@ -79,7 +60,6 @@ py_library(
deps = [
":batching",
":interleave_ops",
- ":optimization",
":parsing_ops",
":shuffle_ops",
"//tensorflow/python:constant_op",
@@ -91,6 +71,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/data/util:convert",
@@ -106,7 +87,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
],
)
@@ -125,6 +106,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
@@ -138,8 +120,7 @@ py_library(
srcs = ["enumerate_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
],
)
@@ -148,10 +129,7 @@ py_library(
srcs = ["error_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:error_ops",
],
)
@@ -160,16 +138,7 @@ py_library(
srcs = ["grouping.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:grouping",
],
)
@@ -178,30 +147,7 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":random_ops",
- "//tensorflow/contrib/stateless",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-py_library(
- name = "optimization",
- srcs = ["optimization.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
],
)
@@ -210,25 +156,7 @@ py_library(
srcs = ["parsing_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_library(
- name = "map_defun",
- srcs = ["map_defun.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:parsing_ops",
],
)
@@ -237,18 +165,7 @@ py_library(
srcs = ["resampling.py"],
srcs_version = "PY2AND3",
deps = [
- ":batching",
- ":interleave_ops",
- ":scan_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:logging_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
+ "//tensorflow/python/data/experimental/ops:resampling",
],
)
@@ -257,12 +174,7 @@ py_library(
srcs = ["scan_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
],
)
@@ -282,31 +194,11 @@ py_library(
)
py_library(
- name = "stats_ops",
- srcs = ["stats_ops.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-py_library(
name = "threadpool",
srcs = ["threadpool.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
+ "//tensorflow/python/data/experimental/ops:threadpool",
],
)
@@ -317,11 +209,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:unique",
],
)
@@ -332,20 +220,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_library(
- name = "indexed_dataset_ops",
- srcs = ["indexed_dataset_ops.py"],
- deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:writers",
],
)
@@ -353,11 +228,7 @@ py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
],
)
@@ -370,17 +241,14 @@ py_library(
":error_ops",
":get_single_element",
":grouping",
- ":indexed_dataset_ops",
":interleave_ops",
- ":map_defun",
- ":optimization",
":prefetching_ops",
+ ":random_ops",
":readers",
":resampling",
":scan_ops",
":shuffle_ops",
":sliding",
- ":stats_ops",
":threadpool",
":unique",
":writers",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 7a0f221284..8c60459ca8 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,134 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import get_single_element
-from tensorflow.contrib.data.python.ops import grouping
from tensorflow.contrib.framework import with_shape
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation
-def batch_window(dataset):
- """Batches a window of tensors.
-
- Args:
- dataset: the input dataset.
-
- Returns:
- A `Tensor` representing the batch of the entire input dataset.
- """
- if isinstance(dataset.output_classes, tuple):
- raise TypeError("Input dataset expected to have a single component")
- if dataset.output_classes is ops.Tensor:
- return _batch_dense_window(dataset)
- elif dataset.output_classes is sparse_tensor.SparseTensor:
- return _batch_sparse_window(dataset)
- else:
- raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
-
-
-def _batch_dense_window(dataset):
- """Batches a window of dense tensors."""
-
- def key_fn(_):
- return np.int64(0)
-
- def shape_init_fn(_):
- return array_ops.shape(first_element)
-
- def shape_reduce_fn(state, value):
- check_ops.assert_equal(state, array_ops.shape(value))
- return state
-
- def finalize_fn(state):
- return state
-
- if dataset.output_shapes.is_fully_defined():
- shape = dataset.output_shapes
- else:
- first_element = get_single_element.get_single_element(dataset.take(1))
- shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
- finalize_fn)
- shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
-
- def batch_init_fn(_):
- batch_shape = array_ops.concat([[0], shape], 0)
- return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
-
- def batch_reduce_fn(state, value):
- return array_ops.concat([state, [value]], 0)
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
-def _batch_sparse_window(dataset):
- """Batches a window of sparse tensors."""
-
- def key_fn(_):
- return np.int64(0)
-
- def shape_init_fn(_):
- return first_element.dense_shape
-
- def shape_reduce_fn(state, value):
- check_ops.assert_equal(state, value.dense_shape)
- return state
-
- def finalize_fn(state):
- return state
-
- if dataset.output_shapes.is_fully_defined():
- shape = dataset.output_shapes
- else:
- first_element = get_single_element.get_single_element(dataset.take(1))
- shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
- finalize_fn)
- shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
-
- def batch_init_fn(_):
- indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
- return sparse_tensor.SparseTensor(
- indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
- values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
- dense_shape=array_ops.concat(
- [np.array([0], dtype=np.int64),
- math_ops.cast(shape, dtypes.int64)], 0))
-
- def batch_reduce_fn(state, value):
- return sparse_ops.sparse_concat(0, [state, value])
-
- def reshape_fn(value):
- return sparse_ops.sparse_reshape(
- value,
- array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.map(reshape_fn).apply(
- grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.dense_to_sparse_batch(...)`.")
def dense_to_sparse_batch(batch_size, row_shape):
"""A transformation that batches ragged elements into `tf.SparseTensor`s.
@@ -187,201 +67,10 @@ def dense_to_sparse_batch(batch_size, row_shape):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
-
- return _apply_fn
-
-
-def padded_batch_window(dataset, padded_shape, padding_value=None):
- """Batches a window of tensors with padding.
-
- Args:
- dataset: the input dataset.
-    padded_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
- object representing the shape to which the input elements should be padded
- prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
- `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
- maximum size of that dimension in each batch.
- padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
- padding value to use. Defaults are `0` for numeric types and the empty
- string for string types. If `dataset` contains `tf.SparseTensor`, this
- value is ignored.
-
- Returns:
- A `Tensor` representing the batch of the entire input dataset.
-
- Raises:
- ValueError: if invalid arguments are provided.
- """
- if not issubclass(dataset.output_classes,
- (ops.Tensor, sparse_tensor.SparseTensor)):
- raise TypeError("Input dataset expected to have a single tensor component")
- if issubclass(dataset.output_classes, (ops.Tensor)):
- return _padded_batch_dense_window(dataset, padded_shape, padding_value)
- elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)):
- if padding_value is not None:
- raise ValueError("Padding value not allowed for sparse tensors")
- return _padded_batch_sparse_window(dataset, padded_shape)
- else:
- raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
-
-
-def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
- """Batches a window of dense tensors with padding."""
-
- padded_shape = math_ops.cast(
- convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
-
- def key_fn(_):
- return np.int64(0)
-
- def max_init_fn(_):
- return padded_shape
-
- def max_reduce_fn(state, value):
- """Computes the maximum shape to pad to."""
- condition = math_ops.reduce_all(
- math_ops.logical_or(
- math_ops.less_equal(array_ops.shape(value), padded_shape),
- math_ops.equal(padded_shape, -1)))
- assert_op = control_flow_ops.Assert(condition, [
- "Actual shape greater than padded shape: ",
- array_ops.shape(value), padded_shape
- ])
- with ops.control_dependencies([assert_op]):
- return math_ops.maximum(state, array_ops.shape(value))
-
- def finalize_fn(state):
- return state
-
- # Compute the padded shape.
- max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
- padded_shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
-
- if padding_value is None:
- if dataset.output_types == dtypes.string:
- padding_value = ""
- elif dataset.output_types == dtypes.bool:
- padding_value = False
- elif dataset.output_types == dtypes.variant:
- raise TypeError("Unable to create padding for field of type 'variant'")
- else:
- padding_value = 0
-
- def batch_init_fn(_):
- batch_shape = array_ops.concat(
- [np.array([0], dtype=np.int32), padded_shape], 0)
- return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
-
- def batch_reduce_fn(state, value):
- return array_ops.concat([state, [value]], 0)
-
- def pad_fn(value):
- shape = array_ops.shape(value)
- left = array_ops.zeros_like(shape)
- right = padded_shape - shape
- return array_ops.pad(
- value, array_ops.stack([left, right], 1), constant_values=padding_value)
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.map(pad_fn).apply(
- grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
-def _padded_batch_sparse_window(dataset, padded_shape):
- """Batches a window of sparse tensors with padding."""
-
- def key_fn(_):
- return np.int64(0)
-
- def max_init_fn(_):
- return convert.partial_shape_to_tensor(padded_shape)
-
- def max_reduce_fn(state, value):
- """Computes the maximum shape to pad to."""
- condition = math_ops.reduce_all(
- math_ops.logical_or(
- math_ops.less_equal(value.dense_shape, padded_shape),
- math_ops.equal(padded_shape, -1)))
- assert_op = control_flow_ops.Assert(condition, [
- "Actual shape greater than padded shape: ", value.dense_shape,
- padded_shape
- ])
- with ops.control_dependencies([assert_op]):
- return math_ops.maximum(state, value.dense_shape)
-
- def finalize_fn(state):
- return state
-
- # Compute the padded shape.
- max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
- padded_shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
-
- def batch_init_fn(_):
- indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
- 0)
- return sparse_tensor.SparseTensor(
- indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
- values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
- dense_shape=array_ops.concat(
- [np.array([0], dtype=np.int64), padded_shape], 0))
-
- def batch_reduce_fn(state, value):
- padded_value = sparse_tensor.SparseTensor(
- indices=value.indices, values=value.values, dense_shape=padded_shape)
- reshaped_value = sparse_ops.sparse_reshape(
- padded_value,
- array_ops.concat(
- [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
- return sparse_ops.sparse_concat(0, [state, reshaped_value])
-
- reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
-
-
-class _UnbatchDataset(dataset_ops.UnaryDataset):
- """A dataset that splits the elements of its input into multiple elements."""
-
- def __init__(self, input_dataset):
- """See `unbatch()` for more details."""
- super(_UnbatchDataset, self).__init__(input_dataset)
- flat_shapes = nest.flatten(input_dataset.output_shapes)
- if any(s.ndims == 0 for s in flat_shapes):
- raise ValueError("Cannot unbatch an input with scalar components.")
- known_batch_dim = tensor_shape.Dimension(None)
- for s in flat_shapes:
- try:
- known_batch_dim = known_batch_dim.merge_with(s[0])
- except ValueError:
- raise ValueError("Cannot unbatch an input whose components have "
- "different batch sizes.")
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.unbatch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return nest.map_structure(lambda s: s[1:],
- self._input_dataset.output_shapes)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return batching.dense_to_sparse_batch(batch_size, row_shape)
+@deprecation.deprecated(None, "Use `tf.data.experimental.unbatch()`.")
def unbatch():
"""Splits elements of a dataset into multiple elements on the batch dimension.
@@ -403,39 +92,7 @@ def unbatch():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- if not sparse.any_sparse(dataset.output_classes):
- return _UnbatchDataset(dataset)
-
- # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
- # are normalized to the rank-1 dense representation, so that the
- # sparse-oblivious unbatching logic will slice them
- # appropriately. This leads to a somewhat inefficient re-encoding step
- # for all SparseTensor components.
- # TODO(mrry): Consider optimizing this in future
- # if it turns out to be a bottleneck.
- def normalize(arg, *rest):
- if rest:
- return sparse.serialize_many_sparse_tensors((arg,) + rest)
- else:
- return sparse.serialize_many_sparse_tensors(arg)
-
- normalized_dataset = dataset.map(normalize)
-
- # NOTE(mrry): Our `map()` has lost information about the sparseness
- # of any SparseTensor components, so re-apply the structure of the
- # original dataset.
- restructured_dataset = _RestructuredDataset(
- normalized_dataset,
- dataset.output_types,
- dataset.output_shapes,
- dataset.output_classes,
- allow_unsafe_cast=True)
- return _UnbatchDataset(restructured_dataset)
-
- return _apply_fn
+ return batching.unbatch()
@deprecation.deprecated(
@@ -514,135 +171,8 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
-
- def __init__(self, input_dataset, batch_size, row_shape):
- """See `Dataset.dense_to_sparse_batch()` for more details."""
- super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
- if not isinstance(input_dataset.output_types, dtypes.DType):
- raise TypeError("DenseToSparseDataset requires an input whose elements "
- "have a single component, whereas the input has %r." %
- input_dataset.output_types)
- self._input_dataset = input_dataset
- self._batch_size = batch_size
- self._row_shape = row_shape
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.dense_to_sparse_batch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._batch_size,
- row_shape=convert.partial_shape_to_tensor(self._row_shape),
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return sparse_tensor.SparseTensor
-
- @property
- def output_shapes(self):
- return tensor_shape.vector(None).concatenate(self._row_shape)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _RestructuredDataset(dataset_ops.UnaryDataset):
- """An internal helper for changing the structure and shape of a dataset."""
-
- def __init__(self,
- dataset,
- output_types,
- output_shapes=None,
- output_classes=None,
- allow_unsafe_cast=False):
- """Creates a new dataset with the given output types and shapes.
-
- The given `dataset` must have a structure that is convertible:
-    * `dataset.output_types` must be the same as `output_types` modulo nesting.
- * Each shape in `dataset.output_shapes` must be compatible with each shape
- in `output_shapes` (if given).
-
- Note: This helper permits "unsafe casts" for shapes, equivalent to using
- `tf.Tensor.set_shape()` where domain-specific knowledge is available.
-
- Args:
- dataset: A `Dataset` object.
- output_types: A nested structure of `tf.DType` objects.
- output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
- If omitted, the shapes will be inherited from `dataset`.
- output_classes: (Optional.) A nested structure of class types.
- If omitted, the class types will be inherited from `dataset`.
- allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
- reported output types and shapes of the restructured dataset, e.g. to
- switch a sparse tensor represented as `tf.variant` to its user-visible
- type and shape.
-
- Raises:
- ValueError: If either `output_types` or `output_shapes` is not compatible
- with the structure of `dataset`.
- """
- super(_RestructuredDataset, self).__init__(dataset)
- self._input_dataset = dataset
-
- if not allow_unsafe_cast:
- # Validate that the types are compatible.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- flat_original_types = nest.flatten(dataset.output_types)
- flat_new_types = nest.flatten(output_types)
- if flat_original_types != flat_new_types:
- raise ValueError(
- "Dataset with output types %r cannot be restructured to have "
- "output types %r" % (dataset.output_types, output_types))
-
- self._output_types = output_types
-
- if output_shapes is None:
- # Inherit shapes from the original `dataset`.
- self._output_shapes = nest.pack_sequence_as(output_types,
- nest.flatten(
- dataset.output_shapes))
- else:
- if not allow_unsafe_cast:
- # Validate that the shapes are compatible.
- nest.assert_same_structure(output_types, output_shapes)
- flat_original_shapes = nest.flatten(dataset.output_shapes)
- flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
- for original_shape, new_shape in zip(flat_original_shapes,
- flat_new_shapes):
- if not original_shape.is_compatible_with(new_shape):
- raise ValueError(
- "Dataset with output shapes %r cannot be restructured to have "
- "incompatible output shapes %r" % (dataset.output_shapes,
- output_shapes))
- self._output_shapes = nest.map_structure_up_to(
- output_types, tensor_shape.as_shape, output_shapes)
- if output_classes is None:
- # Inherit class types from the original `dataset`.
- self._output_classes = nest.pack_sequence_as(output_types,
- nest.flatten(
- dataset.output_classes))
- else:
- self._output_classes = output_classes
-
- def _as_variant_tensor(self):
- return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
-
+# TODO(b/116817045): Move this to `tf.data.experimental` when the `with_shape()`
+# function is available in the core.
def assert_element_shape(expected_shapes):
"""Assert the shape of this `Dataset`.
@@ -687,7 +217,8 @@ def assert_element_shape(expected_shapes):
def _apply_fn(dataset):
output_shapes = _merge_output_shapes(dataset.output_shapes,
expected_shapes)
- return _RestructuredDataset(
+ # pylint: disable=protected-access
+ return batching._RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,
output_shapes=output_shapes,
@@ -696,49 +227,7 @@ def assert_element_shape(expected_shapes):
return _apply_fn
-class _MapAndBatchDataset(dataset_ops.MapDataset):
- """A `Dataset` that maps a function over a batch of elements."""
-
- def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
- drop_remainder):
- """See `Dataset.map()` for details."""
- super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
- self._batch_size_t = ops.convert_to_tensor(
- batch_size, dtype=dtypes.int64, name="batch_size")
- self._num_parallel_calls_t = ops.convert_to_tensor(
- num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
- self._drop_remainder_t = ops.convert_to_tensor(
- drop_remainder, dtype=dtypes.bool, name="drop_remainder")
-
- self._batch_size = batch_size
- self._drop_remainder = drop_remainder
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.map_and_batch_dataset_v2(
- input_resource,
- self._map_func.captured_inputs,
- f=self._map_func,
- batch_size=self._batch_size_t,
- num_parallel_calls=self._num_parallel_calls_t,
- drop_remainder=self._drop_remainder_t,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_shapes(self):
- dim = self._batch_size if self._drop_remainder else None
- return nest.pack_sequence_as(self._output_shapes, [
- tensor_shape.vector(dim).concatenate(s)
- for s in nest.flatten(self._output_shapes)
- ])
-
- @property
- def output_types(self):
- return self._output_types
-
-
+@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch(...)`.")
def map_and_batch(map_func,
batch_size,
num_parallel_batches=None,
@@ -779,17 +268,5 @@ def map_and_batch(map_func,
ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
specified.
"""
-
- if num_parallel_batches is None and num_parallel_calls is None:
- num_parallel_calls = batch_size
- elif num_parallel_batches is not None and num_parallel_calls is None:
- num_parallel_calls = batch_size * num_parallel_batches
- elif num_parallel_batches is not None and num_parallel_calls is not None:
- raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
- "arguments are mutually exclusive.")
-
- def _apply_fn(dataset):
- return _MapAndBatchDataset(dataset, map_func, batch_size,
- num_parallel_calls, drop_remainder)
-
- return _apply_fn
+ return batching.map_and_batch(map_func, batch_size, num_parallel_batches,
+ drop_remainder, num_parallel_calls)
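
As the rewritten wrappers above show, the contrib batching transformations now forward to `tf.data.experimental` and merely add a deprecation warning. A minimal migration sketch for user code; the map function and parameter values are illustrative, not taken from this diff:

import tensorflow as tf

def parse_fn(x):
  # Stand-in per-element transformation.
  return tf.cast(x, tf.float32) / 255.0

ds = tf.data.Dataset.range(1000)

# Before: ds.apply(tf.contrib.data.map_and_batch(parse_fn, 32))  (still works, with a warning)
ds = ds.apply(tf.data.experimental.map_and_batch(
    parse_fn, batch_size=32, num_parallel_calls=4))

# Likewise, tf.contrib.data.unbatch() becomes:
flat = ds.apply(tf.data.experimental.unbatch())
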
diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py
index 6ef65f9624..4ff5bf3e39 100644
--- a/tensorflow/contrib/data/python/ops/counter.py
+++ b/tensorflow/contrib/data/python/ops/counter.py
@@ -17,13 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import scan_ops
-
-from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.Counter(...)`.")
def Counter(start=0, step=1, dtype=dtypes.int64):
"""Creates a `Dataset` that counts from `start` in steps of size `step`.
@@ -46,8 +45,4 @@ def Counter(start=0, step=1, dtype=dtypes.int64):
Returns:
A `Dataset` of scalar `dtype` elements.
"""
- with ops.name_scope("counter"):
- start = ops.convert_to_tensor(start, dtype=dtype, name="start")
- step = ops.convert_to_tensor(step, dtype=dtype, name="step")
- return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
- scan_ops.scan(start, lambda state, _: (state + step, state)))
+ return counter.Counter(start, step, dtype)
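
A short sketch of the forwarded endpoint; the step value and the zipped dataset are illustrative:

import tensorflow as tf

# An unbounded stream of indices, as produced by the deprecated tf.contrib.data.Counter.
steps = tf.data.experimental.Counter(start=0, step=2)    # 0, 2, 4, ...
names = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
indexed = tf.data.Dataset.zip((steps, names))            # (0, "a"), (2, "b"), (4, "c")
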
diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py
index 490281e0d2..a21da4d3ec 100644
--- a/tensorflow/contrib/data/python/ops/enumerate_ops.py
+++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
+from tensorflow.python.data.experimental.ops import enumerate_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.enumerate_dataset(...)`.")
def enumerate_dataset(start=0):
"""A transformation that enumerate the elements of a dataset.
@@ -49,10 +50,4 @@ def enumerate_dataset(start=0):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
- return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
- dataset))
-
- return _apply_fn
+ return enumerate_ops.enumerate_dataset(start)
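
The forwarded `enumerate_dataset()` pairs each element with a running index; a one-line sketch with an illustrative start value:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([10, 20, 30])
ds = ds.apply(tf.data.experimental.enumerate_dataset(start=100))  # (100, 10), (101, 20), (102, 30)
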
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index f962e623ee..0559a2e09c 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,10 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.ignore_errors()`.")
def ignore_errors():
"""Creates a `Dataset` from another `Dataset` and silently ignores any errors.
@@ -43,34 +44,4 @@ def ignore_errors():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _IgnoreErrorsDataset(dataset)
-
- return _apply_fn
-
-
-class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that silently ignores errors when computing its input."""
-
- def __init__(self, input_dataset):
- """See `Dataset.ignore_errors()` for details."""
- super(_IgnoreErrorsDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return error_ops.ignore_errors()
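A short sketch of the replacement endpoint; the division-by-zero pipeline is just an illustrative way to produce a failing element, not part of the patch:

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
    ds = ds.map(lambda x: tf.check_numerics(1. / x, "error"))  # 1./0. -> Inf -> error

    # Deprecated: ds.apply(tf.contrib.data.ignore_errors())
    ds = ds.apply(tf.data.experimental.ignore_errors())
    # The failing element is dropped silently; 1.0, 0.5 and 0.25 pass through.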
diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py
index a6713b017a..58ad9eea90 100644
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -19,13 +19,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.ops import get_single_element as experimental_get_single_element
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.get_single_element(...)`.")
def get_single_element(dataset):
"""Returns the single element in `dataset` as a nested structure of tensors.
@@ -61,18 +61,10 @@ def get_single_element(dataset):
InvalidArgumentError (at runtime): if `dataset` does not contain exactly
one element.
"""
- if not isinstance(dataset, dataset_ops.Dataset):
- raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
-
- nested_ret = nest.pack_sequence_as(
- dataset.output_types, gen_dataset_ops.dataset_to_single_element(
- dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(dataset)))
- return sparse.deserialize_sparse_tensors(
- nested_ret, dataset.output_types, dataset.output_shapes,
- dataset.output_classes)
+ return experimental_get_single_element.get_single_element(dataset)
+@deprecation.deprecated(None, "Use `tf.data.Dataset.reduce(...)`.")
def reduce_dataset(dataset, reducer):
"""Returns the result of reducing the `dataset` using `reducer`.
@@ -90,11 +82,4 @@ def reduce_dataset(dataset, reducer):
if not isinstance(dataset, dataset_ops.Dataset):
raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- # The sentinel dataset is used in case the reduced dataset is empty.
- sentinel_dataset = dataset_ops.Dataset.from_tensors(
- reducer.finalize_func(reducer.init_func(np.int64(0))))
- reduced_dataset = dataset.apply(
- grouping.group_by_reducer(lambda x: np.int64(0), reducer))
-
- return get_single_element(
- reduced_dataset.concatenate(sentinel_dataset).take(1))
+ return dataset.reduce(reducer.init_func(np.int64(0)), reducer.reduce_func)
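Sketches of the two replacements in this hunk, assuming TF 1.x graph mode; the concrete datasets are illustrative:

    import numpy as np
    import tensorflow as tf

    # get_single_element: deprecated contrib wrapper -> experimental endpoint.
    one = tf.data.Dataset.from_tensors(42)
    element = tf.data.experimental.get_single_element(one)

    # reduce_dataset(dataset, reducer) now lowers onto the core reduce:
    total = tf.data.Dataset.range(10).reduce(
        np.int64(0), lambda state, x: state + x)  # 0 + 1 + ... + 9 = 45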
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 7cae33beb3..a99dc2f29a 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -17,20 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.group_by_reducer(...)`.")
def group_by_reducer(key_func, reducer):
"""A transformation that groups elements and performs a reduction.
@@ -52,14 +45,11 @@ def group_by_reducer(key_func, reducer):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _GroupByReducerDataset(dataset, key_func, reducer)
-
- return _apply_fn
+ return grouping.group_by_reducer(key_func, reducer)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.group_by_window(...)`.")
def group_by_window(key_func,
reduce_func,
window_size=None,
@@ -98,27 +88,12 @@ def group_by_window(key_func,
ValueError: if neither or both of {`window_size`, `window_size_func`} are
passed.
"""
- if (window_size is not None and window_size_func or
- not (window_size is not None or window_size_func)):
- raise ValueError("Must pass either window_size or window_size_func.")
-
- if window_size is not None:
-
- def constant_window_func(unused_key):
- return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
-
- window_size_func = constant_window_func
-
- assert window_size_func is not None
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _GroupByWindowDataset(dataset, key_func, reduce_func,
- window_size_func)
-
- return _apply_fn
+ return grouping.group_by_window(key_func, reduce_func, window_size,
+ window_size_func)
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.bucket_by_sequence_length(...)`.")
def bucket_by_sequence_length(element_length_func,
bucket_boundaries,
bucket_batch_sizes,
@@ -163,342 +138,12 @@ def bucket_by_sequence_length(element_length_func,
Raises:
ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
"""
- with ops.name_scope("bucket_by_seq_length"):
- if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
- raise ValueError(
- "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
-
- batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
-
- def element_to_bucket_id(*args):
- """Return int64 id of the length bucket for this element."""
- seq_length = element_length_func(*args)
-
- boundaries = list(bucket_boundaries)
- buckets_min = [np.iinfo(np.int32).min] + boundaries
- buckets_max = boundaries + [np.iinfo(np.int32).max]
- conditions_c = math_ops.logical_and(
- math_ops.less_equal(buckets_min, seq_length),
- math_ops.less(seq_length, buckets_max))
- bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
-
- return bucket_id
-
- def window_size_fn(bucket_id):
- # The window size is set to the batch size for this bucket
- window_size = batch_sizes[bucket_id]
- return window_size
-
- def make_padded_shapes(shapes, none_filler=None):
- padded = []
- for shape in nest.flatten(shapes):
- shape = tensor_shape.TensorShape(shape)
- shape = [
- none_filler if d.value is None else d
- for d in shape
- ]
- padded.append(shape)
- return nest.pack_sequence_as(shapes, padded)
-
- def batching_fn(bucket_id, grouped_dataset):
- """Batch elements in dataset."""
- batch_size = window_size_fn(bucket_id)
- if no_padding:
- return grouped_dataset.batch(batch_size)
- none_filler = None
- if pad_to_bucket_boundary:
- err_msg = ("When pad_to_bucket_boundary=True, elements must have "
- "length < max(bucket_boundaries).")
- check = check_ops.assert_less(
- bucket_id,
- constant_op.constant(len(bucket_batch_sizes) - 1,
- dtype=dtypes.int64),
- message=err_msg)
- with ops.control_dependencies([check]):
- boundaries = constant_op.constant(bucket_boundaries,
- dtype=dtypes.int64)
- bucket_boundary = boundaries[bucket_id]
- none_filler = bucket_boundary - 1
- shapes = make_padded_shapes(
- padded_shapes or grouped_dataset.output_shapes,
- none_filler=none_filler)
- return grouped_dataset.padded_batch(batch_size, shapes, padding_values)
-
- def _apply_fn(dataset):
- return dataset.apply(
- group_by_window(element_to_bucket_id, batching_fn,
- window_size_func=window_size_fn))
-
- return _apply_fn
-
-
-def _map_x_dataset(map_func):
- """A transformation that maps `map_func` across its input.
-
- This transformation is similar to `tf.data.Dataset.map`, but in addition to
- supporting dense and sparse tensor inputs, it also supports dataset inputs.
-
- Args:
- map_func: A function mapping a nested structure of tensors and/or datasets
- (having shapes and types defined by `self.output_shapes` and
- `self.output_types`) to another nested structure of tensors and/or
- datasets.
-
- Returns:
- Dataset: A `Dataset`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _MapXDataset(dataset, map_func)
-
- return _apply_fn
-
-
-# TODO(b/115382007) Remove this once canned reducers move to core.
-def window_dataset(window_size):
- """A transformation that creates window datasets from the input dataset.
-
- The resulting datasets will contain `window_size` elements (or
- `N % window_size` for the last dataset if `window_size` does not divide the
- number of input elements `N` evenly).
-
- Args:
- window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
- consecutive elements of the input dataset to combine into a window.
-
- Returns:
- Dataset: A `Dataset`.
- """
-
- def _apply_fn(dataset):
- return dataset_ops.WindowDataset(
- dataset,
- size=window_size,
- shift=window_size,
- stride=1,
- drop_remainder=False)
-
- return _apply_fn
-
-
-class _GroupByReducerDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that groups its input and performs a reduction."""
-
- def __init__(self, input_dataset, key_func, reducer):
- """See `group_by_reducer()` for details."""
- super(_GroupByReducerDataset, self).__init__(input_dataset)
+ return grouping.bucket_by_sequence_length(
+ element_length_func, bucket_boundaries, bucket_batch_sizes, padded_shapes,
+ padding_values, pad_to_bucket_boundary, no_padding)
- self._input_dataset = input_dataset
- self._make_key_func(key_func, input_dataset)
- self._make_init_func(reducer.init_func)
- self._make_reduce_func(reducer.reduce_func, input_dataset)
- self._make_finalize_func(reducer.finalize_func)
-
- def _make_key_func(self, key_func, input_dataset):
- """Make wrapping Defun for key_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- key_func, "tf.contrib.data.group_by_reducer()", input_dataset)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`key_func` must return a single tf.int64 tensor. "
- "Got type=%s and shape=%s"
- % (wrapped_func.output_types, wrapped_func.output_shapes))
- self._key_func = wrapped_func.function
-
- def _make_init_func(self, init_func):
- """Make wrapping Defun for init_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- init_func, "tf.contrib.data.group_by_reducer()",
- input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
- input_types=dtypes.int64)
- self._init_func = wrapped_func.function
- self._state_classes = wrapped_func.output_classes
- self._state_shapes = wrapped_func.output_shapes
- self._state_types = wrapped_func.output_types
-
- def _make_reduce_func(self, reduce_func, input_dataset):
- """Make wrapping Defun for reduce_func."""
-
- # Iteratively rerun the reduce function until reaching a fixed point on
- # `self._state_shapes`.
- need_to_rerun = True
- while need_to_rerun:
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- reduce_func, "tf.contrib.data.group_by_reducer()",
- input_classes=(self._state_classes, input_dataset.output_classes),
- input_shapes=(self._state_shapes, input_dataset.output_shapes),
- input_types=(self._state_types, input_dataset.output_types),
- add_to_graph=False)
-
- # Extract and validate class information from the returned values.
- for new_state_class, state_class in zip(
- nest.flatten(wrapped_func.output_classes),
- nest.flatten(self._state_classes)):
- if not issubclass(new_state_class, state_class):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes, wrapped_func.output_classes))
-
- # Extract and validate type information from the returned values.
- for new_state_type, state_type in zip(
- nest.flatten(wrapped_func.output_types),
- nest.flatten(self._state_types)):
- if new_state_type != state_type:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types, wrapped_func.output_types))
-
- # Extract shape information from the returned values.
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
- weakened_state_shapes = [
- original.most_specific_compatible_shape(new)
- for original, new in zip(flat_state_shapes, flat_new_state_shapes)
- ]
-
- need_to_rerun = False
- for original_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if original_shape.ndims is not None and (
- weakened_shape.ndims is None or
- original_shape.as_list() != weakened_shape.as_list()):
- need_to_rerun = True
- break
-
- if need_to_rerun:
- self._state_shapes = nest.pack_sequence_as(self._state_shapes,
- weakened_state_shapes)
-
- self._reduce_func = wrapped_func.function
- self._reduce_func.add_to_graph(ops.get_default_graph())
-
- def _make_finalize_func(self, finalize_func):
- """Make wrapping Defun for finalize_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- finalize_func, "tf.contrib.data.group_by_reducer()",
- input_classes=self._state_classes, input_shapes=self._state_shapes,
- input_types=self._state_types)
- self._finalize_func = wrapped_func.function
- self._output_classes = wrapped_func.output_classes
- self._output_shapes = wrapped_func.output_shapes
- self._output_types = wrapped_func.output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.group_by_reducer_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._key_func.captured_inputs,
- self._init_func.captured_inputs,
- self._reduce_func.captured_inputs,
- self._finalize_func.captured_inputs,
- key_func=self._key_func,
- init_func=self._init_func,
- reduce_func=self._reduce_func,
- finalize_func=self._finalize_func,
- **dataset_ops.flat_structure(self))
-
-
-class _GroupByWindowDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that groups its input and performs a windowed reduction."""
-
- def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
- """See `group_by_window()` for details."""
- super(_GroupByWindowDataset, self).__init__(input_dataset)
-
- self._input_dataset = input_dataset
-
- self._make_key_func(key_func, input_dataset)
- self._make_reduce_func(reduce_func, input_dataset)
- self._make_window_size_func(window_size_func)
-
- def _make_window_size_func(self, window_size_func):
- """Make wrapping Defun for window_size_func."""
- def window_size_func_wrapper(key):
- return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- window_size_func_wrapper, "tf.contrib.data.group_by_window()",
- input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
- input_types=dtypes.int64)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`window_size_func` must return a single tf.int64 scalar tensor.")
- self._window_size_func = wrapped_func.function
-
- def _make_key_func(self, key_func, input_dataset):
- """Make wrapping Defun for key_func."""
- def key_func_wrapper(*args):
- return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`key_func` must return a single tf.int64 scalar tensor.")
- self._key_func = wrapped_func.function
-
- def _make_reduce_func(self, reduce_func, input_dataset):
- """Make wrapping Defun for reduce_func."""
- nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- reduce_func, "tf.contrib.data.reduce_by_window()",
- input_classes=(ops.Tensor, nested_dataset),
- input_shapes=(tensor_shape.scalar(), nested_dataset),
- input_types=(dtypes.int64, nested_dataset),
- experimental_nested_dataset_support=True)
- if not isinstance(
- wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
- raise TypeError("`reduce_func` must return a `Dataset` object.")
- self._output_classes = wrapped_func.output_classes.output_classes
- self._output_types = wrapped_func.output_types.output_types
- self._output_shapes = wrapped_func.output_shapes.output_shapes
- self._reduce_func = wrapped_func.function
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.group_by_window_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._key_func.captured_inputs,
- self._reduce_func.captured_inputs,
- self._window_size_func.captured_inputs,
- key_func=self._key_func,
- reduce_func=self._reduce_func,
- window_size_func=self._window_size_func,
- **dataset_ops.flat_structure(self))
-
-
-class Reducer(object):
+class Reducer(grouping.Reducer):
"""A reducer is used for reducing a set of elements.
A reducer is represented as a tuple of the three functions:
@@ -507,58 +152,6 @@ class Reducer(object):
3) finalization function: state => result
"""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.Reducer(...)`.")
def __init__(self, init_func, reduce_func, finalize_func):
- self._init_func = init_func
- self._reduce_func = reduce_func
- self._finalize_func = finalize_func
-
- @property
- def init_func(self):
- return self._init_func
-
- @property
- def reduce_func(self):
- return self._reduce_func
-
- @property
- def finalize_func(self):
- return self._finalize_func
-
-
-class _MapXDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that maps a function over elements in its input."""
-
- def __init__(self, input_dataset, map_func):
- """See `map_x_dataset()` for details."""
- super(_MapXDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- map_func,
- "tf.contrib.data.map_x_dataset()",
- input_dataset,
- experimental_nested_dataset_support=True)
- self._output_classes = wrapped_func.output_classes
- self._output_shapes = wrapped_func.output_shapes
- self._output_types = wrapped_func.output_types
- self._map_func = wrapped_func.function
-
- def _as_variant_tensor(self):
- input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
- return gen_dataset_ops.map_dataset(
- input_t,
- self._map_func.captured_inputs,
- f=self._map_func,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
+ super(Reducer, self).__init__(init_func, reduce_func, finalize_func)
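A brief usage sketch for the grouping endpoints after this migration; the parity key and batch size are illustrative choices:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)

    # Deprecated: ds.apply(tf.contrib.data.group_by_window(...))
    ds = ds.apply(tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,                      # even / odd buckets
        reduce_func=lambda key, window: window.batch(4),
        window_size=4))
    # Yields batches of up to 4 elements whose members share the same parity.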
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
deleted file mode 100644
index 9c06474a2f..0000000000
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Python wrappers for indexed datasets."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-
-
-class MaterializedIndexedDataset(object):
- """MaterializedIndexedDataset is highly experimental!
- """
-
- def __init__(self, materialized_resource, materializer, output_classes,
- output_types, output_shapes):
- self._materialized_resource = materialized_resource
- self._materializer = materializer
- self._output_classes = output_classes
- self._output_types = output_types
- self._output_shapes = output_shapes
-
- @property
- def initializer(self):
- if self._materializer is not None:
- return self._materializer
- raise ValueError("MaterializedDataset does not have a materializer")
-
- def get(self, index):
- """Get retrieves a value (or set of values) from the IndexedDataset.
-
- Args:
- index: A uint64 scalar or vector tensor with the indices to retrieve.
-
- Returns:
- A tensor containing the values corresponding to `index`.
- """
- # TODO(saeta): nest.pack_sequence_as(...)
- return ged_ops.experimental_indexed_dataset_get(
- self._materialized_resource,
- index,
- output_types=nest.flatten(
- sparse.as_dense_types(self._output_types, self._output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self._output_shapes, self._output_classes)))
-
-
-class IndexedDataset(dataset_ops.Dataset):
- """IndexedDataset is highly experimental!
- """
-
- def __init__(self):
- pass
-
- def materialize(self, shared_name=None, container=None):
- """Materialize creates a MaterializedIndexedDataset.
-
- IndexedDatasets can be combined through operations such as TBD. Therefore,
- they are only materialized when absolutely required.
-
- Args:
- shared_name: a string for the shared name to use for the resource.
- container: a string for the container to store the resource.
-
- Returns:
- A MaterializedIndexedDataset.
- """
- if container is None:
- container = ""
- if shared_name is None:
- shared_name = ""
- materialized_resource = (
- ged_ops.experimental_materialized_index_dataset_handle(
- container=container,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self.output_shapes,
- self.output_classes))))
-
- with ops.colocate_with(materialized_resource):
- materializer = ged_ops.experimental_indexed_dataset_materialize(
- self._as_variant_tensor(), materialized_resource)
- return MaterializedIndexedDataset(materialized_resource, materializer,
- self.output_classes, self.output_types,
- self.output_shapes)
-
- @abc.abstractproperty
- def output_types(self):
- """Returns the type of each component of an element of this IndexedDataset.
-
- Returns:
- A nested structure of `tf.DType` objects corresponding to each component
- of an element of this IndexedDataset.
- """
- raise NotImplementedError("IndexedDataset.output_types")
-
- @abc.abstractproperty
- def output_classes(self):
- """Returns the class of each component of an element of this IndexedDataset.
-
- The expected values are `tf.Tensor` and `tf.SparseTensor`.
-
- Returns:
- A nested structure of Python `type` objects corresponding to each
- component of an element of this IndexedDataset.
- """
- raise NotImplementedError("IndexedDataset.output_classes")
-
- @abc.abstractproperty
- def output_shapes(self):
- """Returns the shape of each component of an element of this IndexedDataset.
-
- Returns:
- A nested structure of `tf.TensorShape` objects corresponding to each
- component of an element of this IndexedDataset.
- """
- raise NotImplementedError("IndexedDataset.output_shapes")
-
- @abc.abstractmethod
- def _as_variant_tensor(self):
- """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset.
-
- Returns:
- A scalar `tf.Tensor` of `tf.variant` type, which represents this
- IndexedDataset.
- """
- raise NotImplementedError("IndexedDataset._as_variant_tensor")
-
-
-class IdentityIndexedDataset(IndexedDataset):
- """IdentityIndexedDataset is a trivial indexed dataset used for testing.
- """
-
- def __init__(self, size):
- super(IdentityIndexedDataset, self).__init__()
- # TODO(saeta): Verify _size is a scalar!
- self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size")
-
- @property
- def output_types(self):
- return dtypes.uint64
-
- @property
- def output_classes(self):
- return ops.Tensor
-
- @property
- def output_shapes(self):
- return tensor_shape.scalar()
-
- def _as_variant_tensor(self):
- return ged_ops.experimental_identity_indexed_dataset(self._size)
-
- def _inputs(self):
- return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 1ee9db1aa8..f50da4d429 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,20 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import stateless
-from tensorflow.contrib.data.python.ops import random_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import readers
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.parallel_interleave(...)`.")
def parallel_interleave(map_func,
cycle_length,
block_length=1,
@@ -80,12 +72,9 @@ def parallel_interleave(map_func,
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return readers.ParallelInterleaveDataset(
- dataset, map_func, cycle_length, block_length, sloppy,
- buffer_output_elements, prefetch_input_elements)
-
- return _apply_fn
+ return interleave_ops.parallel_interleave(
+ map_func, cycle_length, block_length, sloppy, buffer_output_elements,
+ prefetch_input_elements)
@deprecation.deprecated(
@@ -139,63 +128,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return readers.ParallelInterleaveDataset(
- dataset,
- map_func,
- cycle_length,
- block_length,
- sloppy=True,
- buffer_output_elements=None,
- prefetch_input_elements=None)
-
- return _apply_fn
-
-
-class _DirectedInterleaveDataset(dataset_ops.Dataset):
- """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
-
- def __init__(self, selector_input, data_inputs):
- self._selector_input = selector_input
- self._data_inputs = list(data_inputs)
-
- for data_input in data_inputs[1:]:
- if (data_input.output_types != data_inputs[0].output_types or
- data_input.output_classes != data_inputs[0].output_classes):
- raise TypeError("All datasets must have the same type and class.")
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- return (
- gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
- self._selector_input._as_variant_tensor(), [
- data_input._as_variant_tensor()
- for data_input in self._data_inputs
- ], **dataset_ops.flat_structure(self)))
- # pylint: enable=protected-access
-
- def _inputs(self):
- return [self._selector_input] + self._data_inputs
-
- @property
- def output_classes(self):
- return self._data_inputs[0].output_classes
-
- @property
- def output_shapes(self):
- ret = self._data_inputs[0].output_shapes
- for data_input in self._data_inputs[1:]:
- ret = nest.pack_sequence_as(ret, [
- ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
- nest.flatten(ret), nest.flatten(data_input.output_shapes))
- ])
- return ret
-
- @property
- def output_types(self):
- return self._data_inputs[0].output_types
+ return interleave_ops.parallel_interleave(
+ map_func, cycle_length, block_length, sloppy=True)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.sample_from_datasets(...)`.")
def sample_from_datasets(datasets, weights=None, seed=None):
"""Samples elements at random from the datasets in `datasets`.
@@ -219,64 +157,11 @@ def sample_from_datasets(datasets, weights=None, seed=None):
ValueError: If the `weights` argument is specified and does not match the
length of the `datasets` element.
"""
- num_datasets = len(datasets)
- if not isinstance(weights, dataset_ops.Dataset):
- if weights is None:
- # Select inputs with uniform probability.
- logits = [[1.0] * num_datasets]
-
- else:
- # Use the given `weights` as the probability of choosing the respective
- # input.
- weights = ops.convert_to_tensor(weights, name="weights")
- if weights.dtype not in (dtypes.float32, dtypes.float64):
- raise TypeError("`weights` must be convertible to a tensor of "
- "`tf.float32` or `tf.float64` elements.")
- if not weights.shape.is_compatible_with([num_datasets]):
- raise ValueError(
- "`weights` must be a vector of length `len(datasets)`.")
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed
- # to weights.
- logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
-
- # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
- # is a `Dataset`, it is possible that evaluating it has a side effect the
- # user depends on.
- if len(datasets) == 1:
- return datasets[0]
-
- def select_dataset_constant_logits(seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
-
- selector_input = dataset_ops.MapDataset(
- random_ops.RandomDataset(seed).batch(2),
- select_dataset_constant_logits,
- use_inter_op_parallelism=False)
-
- else:
- # Use each element of the given `weights` dataset as the probability of
- # choosing the respective input.
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed to
- # weights.
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
-
- def select_dataset_varying_logits(logits, seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
-
- logits_and_seeds = dataset_ops.Dataset.zip(
- (logits_ds, random_ops.RandomDataset(seed).batch(2)))
- selector_input = dataset_ops.MapDataset(
- logits_and_seeds,
- select_dataset_varying_logits,
- use_inter_op_parallelism=False)
-
- return _DirectedInterleaveDataset(selector_input, datasets)
+ return interleave_ops.sample_from_datasets(datasets, weights, seed)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.choose_from_datasets(...)`.")
def choose_from_datasets(datasets, choice_dataset):
"""Creates a dataset that deterministically chooses elements from `datasets`.
@@ -312,10 +197,4 @@ def choose_from_datasets(datasets, choice_dataset):
TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
type.
"""
- if not (choice_dataset.output_types == dtypes.int64
- and choice_dataset.output_shapes.is_compatible_with(
- tensor_shape.scalar())
- and choice_dataset.output_classes == ops.Tensor):
- raise TypeError("`choice_dataset` must be a dataset of scalar "
- "`tf.int64` tensors.")
- return _DirectedInterleaveDataset(choice_dataset, datasets)
+ return interleave_ops.choose_from_datasets(datasets, choice_dataset)
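A minimal sketch of the experimental endpoints this file now delegates to; the weights, seed, and choice pattern are arbitrary illustrations:

    import tensorflow as tf

    zeros = tf.data.Dataset.from_tensors(0).repeat()
    ones = tf.data.Dataset.from_tensors(1).repeat()

    # Deprecated: tf.contrib.data.sample_from_datasets(...)
    mixed = tf.data.experimental.sample_from_datasets(
        [zeros, ones], weights=[0.25, 0.75], seed=42)
    # Roughly a quarter of the elements come from `zeros`, the rest from `ones`.

    # choose_from_datasets picks deterministically from an index dataset:
    choice = tf.data.Dataset.range(2).repeat(3)  # 0, 1, 0, 1, 0, 1
    picked = tf.data.experimental.choose_from_datasets([zeros, ones], choice)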
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 18515e21ed..48c325c86f 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -16,15 +16,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training import session_run_hook
+from tensorflow.python.data.experimental.ops import iterator_ops
+from tensorflow.python.util import deprecation
+
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_saveable_from_iterator(...)`.")
def make_saveable_from_iterator(iterator):
"""Returns a SaveableObject for saving/restore iterator state using Saver.
@@ -60,27 +58,10 @@ def make_saveable_from_iterator(iterator):
Note: Not all iterators support checkpointing yet. Attempting to save the
state of an unsupported iterator will throw an error.
"""
- return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
-
-
-class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
- """SaveableObject for saving/restoring iterator state."""
+ return iterator_ops.make_saveable_from_iterator(iterator)
- def __init__(self, iterator_resource):
- serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
- specs = [
- saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
- iterator_resource.name + "-state")
- ]
- super(_Saveable, self).__init__(iterator_resource, specs,
- iterator_resource.name)
- def restore(self, restored_tensors, unused_restored_shapes):
- with ops.colocate_with(self.op):
- return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
-
-
-class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+class CheckpointInputPipelineHook(iterator_ops.CheckpointInputPipelineHook):
"""Checkpoints input pipeline state every N steps or seconds.
This hook saves the state of the iterators in the `Graph` so that when
@@ -125,135 +106,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
collector when building the eval graph.
"""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.CheckpointInputPipelineHook(...)`.")
def __init__(self, estimator):
- """Initializes a `CheckpointInputPipelineHook`.
-
- Args:
- estimator: Estimator.
-
- Raises:
- ValueError: One of `save_steps` or `save_secs` should be set.
- ValueError: At most one of saver or scaffold should be set.
- """
- # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
- # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
- # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
- # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
- # to be different to avoid conflicts with the model checkpoint.
-
- # pylint: disable=protected-access
- checkpoint_prefix = "input"
- if estimator._config.num_worker_replicas > 1:
- # Distributed setting.
- suffix = "_{}_{}".format(estimator._config.task_type,
- estimator._config.task_id)
- checkpoint_prefix += suffix
- # pylint: enable=protected-access
-
- # We use a composition paradigm instead of inheriting from
- # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
- # to check whether a `CheckpointSaverHook` is already present in the list
- # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
- # would thwart this behavior. This hook checkpoints *only the iterators*
- # and not the graph variables.
- self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
- estimator.model_dir,
- save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
- save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
- checkpoint_basename=checkpoint_prefix + ".ckpt")
-
- # Name for the protocol buffer file that will contain the list of most
- # recent checkpoints stored as a `CheckpointState` protocol buffer.
- # This file, kept in the same directory as the checkpoint files, is
- # automatically managed by the `Saver` to keep track of recent checkpoints.
- # The default name used by the `Saver` for this file is "checkpoint". Here
- # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
- # `checkpoint_dir` is the same as the model checkpoint directory, there are
- # no conflicts during restore.
- self._latest_filename = "checkpoint_" + checkpoint_prefix
- self._first_run = True
-
- def begin(self):
- # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
- # collection if no `Saver` or `Scaffold` is provided.
- # pylint: disable=protected-access
- if (self._checkpoint_saver_hook._saver is None and
- self._checkpoint_saver_hook._scaffold is None):
- iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
- saveables = [_Saveable(i) for i in iterators]
- self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
- self._latest_filename)
- # pylint: enable=protected-access
- self._checkpoint_saver_hook.begin()
-
- def _restore_or_save_initial_ckpt(self, session):
- # Ideally this should be run in after_create_session but is not for the
- # following reason:
- # Currently there is no way of enforcing an order of running the
- # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
- # is run *after* this hook. That is troublesome because
- # 1. If a checkpoint exists and this hook restores it, the initializer hook
- # will override it.
- # 2. If no checkpoint exists, this hook will try to save an initialized
- # iterator which will result in an exception.
- #
- # As a temporary fix we enter the following implicit contract between this
- # hook and the _DatasetInitializerHook.
- # 1. The _DatasetInitializerHook initializes the iterator in the call to
- # after_create_session.
- # 2. This hook saves the iterator on the first call to `before_run()`, which
- # is guaranteed to happen after `after_create_session()` of all hooks
- # have been run.
-
- # Check if there is an existing checkpoint. If so, restore from it.
- # pylint: disable=protected-access
- latest_checkpoint_path = checkpoint_management.latest_checkpoint(
- self._checkpoint_saver_hook._checkpoint_dir,
- latest_filename=self._latest_filename)
- if latest_checkpoint_path:
- self._checkpoint_saver_hook._get_saver().restore(session,
- latest_checkpoint_path)
- else:
- # The checkpoint saved here is the state at step "global_step".
- # Note: We do not save the GraphDef or MetaGraphDef here.
- global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
- self._checkpoint_saver_hook._save(session, global_step)
- self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
- # pylint: enable=protected-access
-
- def before_run(self, run_context):
- if self._first_run:
- self._restore_or_save_initial_ckpt(run_context.session)
- self._first_run = False
- return self._checkpoint_saver_hook.before_run(run_context)
-
- def after_run(self, run_context, run_values):
- self._checkpoint_saver_hook.after_run(run_context, run_values)
-
- def end(self, session):
- self._checkpoint_saver_hook.end(session)
-
-
-class _CustomSaver(saver_lib.Saver):
- """`Saver` with a different default `latest_filename`.
-
- This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
- the model ckpt saved by the `CheckpointSaverHook`.
- """
-
- def __init__(self, var_list, latest_filename):
- super(_CustomSaver, self).__init__(var_list)
- self._latest_filename = latest_filename
-
- def save(self,
- sess,
- save_path,
- global_step=None,
- latest_filename=None,
- meta_graph_suffix="meta",
- write_meta_graph=True,
- write_state=True,
- strip_default_attrs=False):
- return super(_CustomSaver, self).save(
- sess, save_path, global_step, latest_filename or self._latest_filename,
- meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
+ super(CheckpointInputPipelineHook, self).__init__(estimator)
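A sketch of saveable-iterator usage under the new endpoint, assuming TF 1.x graph mode with a `tf.train.Saver`:

    import tensorflow as tf

    ds = tf.data.Dataset.range(100)
    iterator = ds.make_initializable_iterator()

    # Deprecated: tf.contrib.data.make_saveable_from_iterator(iterator)
    saveable = tf.data.experimental.make_saveable_from_iterator(iterator)
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
    saver = tf.train.Saver()  # now checkpoints iterator state alongside variables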
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py
deleted file mode 100644
index 3d0d0993c9..0000000000
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Experimental API for optimizing `tf.data` pipelines."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
-
-
-def map_defun(fn, elems, output_dtypes, output_shapes):
- """Map a function on the list of tensors unpacked from `elems` on dimension 0.
-
- Args:
- fn: A function (`function.Defun`) that takes a list of tensors and returns
- another list of tensors. The output list has the same types as
- output_dtypes. The elements of the output list have the same dimension 0
- as `elems`, and the remaining dimensions correspond to those of
- `fn_output_shapes`.
- elems: A list of tensors.
- output_dtypes: A list of dtypes corresponding to the output types of the
- function.
- output_shapes: A list of `TensorShape`s corresponding to the output
- shapes from each invocation of the function on slices of inputs.
-
- Raises:
- ValueError: if any of the inputs are malformed.
-
- Returns:
- A list of `Tensor` objects with the same types as `output_dtypes`.
- """
- if not isinstance(elems, list):
- raise ValueError("`elems` must be a list of tensors.")
- if not isinstance(output_dtypes, list):
- raise ValueError("`output_dtypes` must be a list of tensors.")
- if not isinstance(output_shapes, list):
- raise ValueError("`output_shapes` must be a list of tensors.")
-
- elems = [ops.convert_to_tensor(e) for e in elems]
- output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
deleted file mode 100644
index 30348ede36..0000000000
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Experimental API for optimizing `tf.data` pipelines."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
-
-# A constant that can be used to enable auto-tuning.
-AUTOTUNE = -1
-
-
-# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
-# account for indexing) and transformation sequence.
-def assert_next(transformations):
- """A transformation that asserts which transformations happen next.
-
- Args:
- transformations: A `tf.string` vector `tf.Tensor` identifying the
- transformations that are expected to happen next.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _AssertNextDataset(dataset, transformations)
-
- return _apply_fn
-
-
-def model():
- """A transformation that models performance.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _ModelDataset(dataset)
-
- return _apply_fn
-
-
-def optimize(optimizations=None):
- """A transformation that applies optimizations.
-
- Args:
- optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying
- optimizations to use. If not specified, the default set of optimizations
- is applied.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _OptimizeDataset(dataset, optimizations)
-
- return _apply_fn
-
-
-class _AssertNextDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that asserts which transformations happen next."""
-
- def __init__(self, input_dataset, transformations):
- """See `assert_next()` for details."""
- super(_AssertNextDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if transformations is None:
- raise ValueError("At least one transformation should be specified")
- self._transformations = ops.convert_to_tensor(
- transformations, dtype=dtypes.string, name="transformations")
-
- def _as_variant_tensor(self):
- return gen_experimental_dataset_ops.experimental_assert_next_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._transformations,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _ModelDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and models performance."""
-
- def __init__(self, input_dataset):
- """See `optimize()` for details."""
- super(_ModelDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.model_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _OptimizeDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and applies optimizations."""
-
- def __init__(self, input_dataset, optimizations):
- """See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if optimizations is None:
- optimizations = []
- self._optimizations = ops.convert_to_tensor(
- optimizations, dtype=dtypes.string, name="optimizations")
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.optimize_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._optimizations,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index cfbba701b0..3aeee9d8e4 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -17,92 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import parsing_ops
+from tensorflow.python.data.experimental.ops import parsing_ops
+from tensorflow.python.util import deprecation
-class _ParseExampleDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that parses `example` dataset into a `dict` dataset."""
-
- def __init__(self, input_dataset, features, num_parallel_calls):
- super(_ParseExampleDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if not all(types == dtypes.string
- for types in nest.flatten(input_dataset.output_types)):
- raise TypeError("Input dataset should be a dataset of vectors of strings")
- self._num_parallel_calls = num_parallel_calls
- # pylint: disable=protected-access
- self._features = parsing_ops._prepend_none_dimension(features)
- # sparse_keys and dense_keys come back sorted here.
- (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
- dense_shapes) = parsing_ops._features_to_raw_params(
- self._features, [
- parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
- parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
- ])
- # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
- (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
- dense_shape_as_shape) = parsing_ops._process_raw_parameters(
- None, dense_defaults, sparse_keys, sparse_types, dense_keys,
- dense_types, dense_shapes)
- # pylint: enable=protected-access
- self._sparse_keys = sparse_keys
- self._sparse_types = sparse_types
- self._dense_keys = dense_keys
- self._dense_defaults = dense_defaults_vec
- self._dense_shapes = dense_shapes
- self._dense_types = dense_types
- dense_output_shapes = [
- self._input_dataset.output_shapes.concatenate(shape)
- for shape in dense_shape_as_shape
- ]
- sparse_output_shapes = [
- self._input_dataset.output_shapes.concatenate([None])
- for _ in range(len(sparse_keys))
- ]
-
- self._output_shapes = dict(
- zip(self._dense_keys + self._sparse_keys,
- dense_output_shapes + sparse_output_shapes))
- self._output_types = dict(
- zip(self._dense_keys + self._sparse_keys,
- self._dense_types + self._sparse_types))
- self._output_classes = dict(
- zip(self._dense_keys + self._sparse_keys,
- [ops.Tensor for _ in range(len(self._dense_defaults))] +
- [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
- ]))
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.parse_example_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._num_parallel_calls,
- self._dense_defaults,
- self._sparse_keys,
- self._dense_keys,
- self._sparse_types,
- self._dense_shapes,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
-
-# TODO(b/111553342): add arguments names and example names as well.
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.parse_example_dataset(...)`.")
def parse_example_dataset(features, num_parallel_calls=1):
"""A transformation that parses `Example` protos into a `dict` of tensors.
@@ -130,21 +50,4 @@ def parse_example_dataset(features, num_parallel_calls=1):
Raises:
ValueError: if features argument is None.
"""
- if features is None:
- raise ValueError("Missing: features was %s." % features)
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
- if any([
- isinstance(feature, parsing_ops.SparseFeature)
- for _, feature in features.items()
- ]):
- # pylint: disable=protected-access
- # pylint: disable=g-long-lambda
- out_dataset = out_dataset.map(
- lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
- features, x), num_parallel_calls=num_parallel_calls)
- return out_dataset
-
- return _apply_fn
+ return parsing_ops.parse_example_dataset(features, num_parallel_calls)
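A usage sketch for the parsing endpoint after this migration; the feature spec and the file name "examples.tfrecord" are hypothetical:

    import tensorflow as tf

    features = {
        "label": tf.FixedLenFeature([], tf.int64),
        "text": tf.VarLenFeature(tf.string),
    }

    ds = tf.data.TFRecordDataset("examples.tfrecord")  # hypothetical input file
    ds = ds.batch(32)

    # Deprecated: ds.apply(tf.contrib.data.parse_example_dataset(features))
    ds = ds.apply(tf.data.experimental.parse_example_dataset(
        features, num_parallel_calls=4))
    # Each element is now a dict of dense/sparse tensors keyed by feature name.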
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 46f82e453a..adfb390cd9 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -17,321 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import device as framework_device
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.ops import resource_variable_ops
-
-
-def function_buffering_resource(string_arg,
- target_device,
- f,
- buffer_size,
- output_types,
- container="",
- shared_name=None,
- name=None):
- """Creates a FunctionBufferingResource.
-
- A FunctionBufferingResource fills up a buffer by calling a function `f` on
- `target_device`. `f` should take in only a single string argument as input.
-
- Args:
- string_arg: The single string argument to the function.
- target_device: The device to run `f` on.
- f: The function to be executed.
- buffer_size: Size of the buffer to be populated.
- output_types: The output types generated by the function.
- container: (Optional) string. Defaults to "".
- shared_name: (Optional) string.
- name: (Optional) string to name the op.
-
- Returns:
- Handle to a FunctionBufferingResource.
- """
- if shared_name is None:
- shared_name = ""
- return ged_ops.experimental_function_buffering_resource(
- string_arg=string_arg,
- target_device=target_device,
- shared_name=shared_name,
- f=f,
- buffer_size=buffer_size,
- container=container,
- name=name,
- output_types=output_types)
-
-
-def function_buffering_resource_get_next(function_buffer_resource,
- output_types,
- name=None):
- return ged_ops.experimental_function_buffering_resource_get_next(
- function_buffer_resource=function_buffer_resource,
- output_types=output_types,
- name=name)
-
-
-def function_buffering_resource_reset(function_buffer_resource, name=None):
- return ged_ops.experimental_function_buffering_resource_reset(
- function_buffer_resource=function_buffer_resource, name=name)
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset
- one_shot: If true, we make a one shot iterator that's already initialized.
-    device: A fully specified device string naming the device to prefetch to.
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server).
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- one_shot,
- device,
- buffer_size,
- shared_name=None):
- self._input_dataset = input_dataset
- self._get_next_call_count = 0
- self._one_shot = one_shot
- if shared_name is None:
- shared_name = ""
-
- if self._one_shot:
- self._input_iterator = input_dataset.make_one_shot_iterator()
- else:
- self._input_iterator = iterator_ops.Iterator.from_structure(
- self._input_dataset.output_types, self._input_dataset.output_shapes,
- shared_name, self._input_dataset.output_classes)
- input_iterator_handle = self._input_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self._input_iterator.output_types,
- self._input_iterator.output_shapes,
- self._input_iterator.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- iterator_device = ged_ops.experimental_iterator_get_device(
- self._input_iterator._iterator_resource)
-
- with ops.device(device):
- self._buffering_resource = function_buffering_resource(
- f=_prefetch_fn,
- target_device=iterator_device,
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes)))
-
- if not self._one_shot:
- reset_op = function_buffering_resource_reset(self._buffering_resource)
- with ops.control_dependencies([reset_op]):
- self._initializer = self._input_iterator.make_initializer(
- self._input_dataset)
-
- def get_next(self, name=None):
- """See `tf.data.Iterator.get_next`."""
- self._get_next_call_count += 1
- if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
- warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
- flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
- self._buffering_resource,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- name=name)
-
- ret = sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self.output_types, flat_ret),
- self.output_types, self.output_shapes, self.output_classes)
-
- for tensor, shape in zip(
- nest.flatten(ret), nest.flatten(self.output_shapes)):
- if isinstance(tensor, ops.Tensor):
- tensor.set_shape(shape)
-
- return ret
-
- @property
- def initializer(self):
- if self._one_shot:
- raise NotImplementedError("Can't initialize a one_shot_iterator")
- return self._initializer
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset
- one_shot: If true, we make a one shot iterator that's already initialized.
-    device: A fully specified device string naming the device to prefetch to.
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server).
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- device,
- buffer_size):
- with ops.device("/device:CPU:0"):
- super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
- input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
- self._resource)
-
- self._device = device
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self.output_types, self.output_shapes, self.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- _prefetch_fn.add_to_graph(None)
-
- with ops.device(device):
- self._buffering_resource = function_buffering_resource(
- f=_prefetch_fn,
- output_types=self._flat_output_types,
- target_device=ged_ops.experimental_iterator_get_device(
- self._resource),
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=iterator_ops._generate_shared_name(
- "function_buffer_resource"))
-
- def _next_internal(self):
- """Returns a nested structure of `tf.Tensor`s containing the next element.
- """
- # This runs in sync mode as iterators use an error status to communicate
- # that there is no more data to iterate over.
- # TODO(b/77291417): Fix
- with context.execution_mode(context.SYNC):
- with ops.device(self._device):
- ret = ged_ops.experimental_function_buffering_resource_get_next(
- function_buffer_resource=self._buffering_resource,
- output_types=self._flat_output_types)
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self._output_types, ret), self._output_types,
- self._output_shapes, self._output_classes)
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` whose iterator prefetches elements to another device."""
-
- def __init__(self, input_dataset, device, buffer_size):
- super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._device = device
- self._buffer_size = buffer_size if buffer_size is not None else 1
-
- # The static analysis cannot tell that the eager iterator's superclass has
- # a `next()` method.
- # pylint: disable=non-iterator-returned
- def __iter__(self):
- """Creates an `Iterator` for enumerating the elements of this dataset.
-
- The returned iterator implements the Python iterator protocol and therefore
- can only be used in eager mode.
-
- Returns:
- An `Iterator` over the elements of this dataset.
-
- Raises:
- RuntimeError: If eager execution is enabled.
- """
- if context.executing_eagerly():
- return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
- self._buffer_size)
- else:
- raise RuntimeError("dataset.__iter__() is only supported when eager "
- "execution is enabled.")
- # pylint: enable=non-iterator-returned
-
- def make_one_shot_iterator(self):
- if context.executing_eagerly():
- return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
- self._buffer_size)
- else:
- return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
- device=self._device,
- buffer_size=self._buffer_size)
-
- def make_initializable_iterator(self, shared_name=None):
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=False,
- device=self._device,
- buffer_size=self._buffer_size,
- shared_name=shared_name)
-
- def _as_variant_tensor(self):
- # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
-    # transformation methods is called).
- # TODO(mrry): Investigate support for chaining further transformations after
- # the prefetch, including GPU support.
- raise NotImplementedError("`prefetch_to_device()` must be the last "
- "transformation in a dataset pipeline.")
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.prefetch_to_device(...)`.")
def prefetch_to_device(device, buffer_size=None):
"""A transformation that prefetches dataset values to the given `device`.
@@ -347,12 +38,10 @@ def prefetch_to_device(device, buffer_size=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return _PrefetchToDeviceDataset(dataset, device, buffer_size)
-
- return _apply_fn
+ return prefetching_ops.prefetch_to_device(device, buffer_size)
+@deprecation.deprecated(None, "Use `tf.data.experimental.copy_to_device(...)`.")
def copy_to_device(target_device, source_device="/cpu:0"):
"""A transformation that copies dataset elements to the given `target_device`.
@@ -364,165 +53,4 @@ def copy_to_device(target_device, source_device="/cpu:0"):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _CopyToDeviceDataset(
- dataset, target_device=target_device, source_device=source_device)
-
- return _apply_fn
-
-
-# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
-# all inputs to the Op are in host memory, thereby avoiding some unnecessary
-# Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that copies elements to another device."""
-
- def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
- """Constructs a _CopyToDeviceDataset.
-
- Args:
- input_dataset: `Dataset` to be copied
- target_device: The name of the device to which elements would be copied.
- source_device: Device where input_dataset would be placed.
- """
- super(_CopyToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._target_device = target_device
- spec = framework_device.DeviceSpec().from_string(self._target_device)
- self._is_gpu_target = (spec.device_type == "GPU")
- self._source_device_string = source_device
- self._source_device = ops.convert_to_tensor(source_device)
-
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._input_dataset.output_shapes,
- self._input_dataset.output_classes))
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes))
-
- @function.Defun()
- def _init_func():
- """Creates an iterator for the input dataset.
-
- Returns:
- A `string` tensor that encapsulates the iterator created.
- """
- # pylint: disable=protected-access
- ds_variant = self._input_dataset._as_variant_tensor()
- resource = gen_dataset_ops.anonymous_iterator(
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- with ops.control_dependencies(
- [gen_dataset_ops.make_iterator(ds_variant, resource)]):
- return gen_dataset_ops.iterator_to_string_handle(resource)
-
- @function.Defun()
- def _remote_init_func():
- return functional_ops.remote_call(
- target=self._source_device,
- args=_init_func.captured_inputs,
- Tout=[dtypes.string],
- f=_init_func)
-
- self._init_func = _remote_init_func
- self._init_captured_args = _remote_init_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _next_func(string_handle):
- """Calls get_next for created iterator.
-
- Args:
- string_handle: An iterator string handle created by _init_func
- Returns:
- The elements generated from `input_dataset`
- """
- with ops.device(self._source_device_string):
- iterator = iterator_ops.Iterator.from_string_handle(
- string_handle, self.output_types, self.output_shapes,
- self.output_classes)
- ret = iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- @function.Defun(dtypes.string)
- def _remote_next_func(string_handle):
- return functional_ops.remote_call(
- target=self._source_device,
- args=[string_handle] + _next_func.captured_inputs,
- Tout=self._flat_output_types,
- f=_next_func)
-
- self._next_func = _remote_next_func
- self._next_captured_args = _remote_next_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _finalize_func(string_handle):
- """Destroys the iterator resource created.
-
- Args:
- string_handle: An iterator string handle created by _init_func
- Returns:
- Tensor constant 0
- """
- iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
- string_handle,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- with ops.control_dependencies([
- resource_variable_ops.destroy_resource_op(
- iterator_resource, ignore_lookup_error=True)]):
- return array_ops.constant(0, dtypes.int64)
-
- @function.Defun(dtypes.string)
- def _remote_finalize_func(string_handle):
- return functional_ops.remote_call(
- target=self._source_device,
- args=[string_handle] + _finalize_func.captured_inputs,
- Tout=[dtypes.int64],
- f=_finalize_func)
-
- self._finalize_func = _remote_finalize_func
- self._finalize_captured_args = _remote_finalize_func.captured_inputs
-
- g = ops.get_default_graph()
- _remote_init_func.add_to_graph(g)
- _remote_next_func.add_to_graph(g)
- _remote_finalize_func.add_to_graph(g)
-    # pylint: enable=protected-access
-
-  # The one_shot_iterator implementation needs a 0-arg _make_dataset function
-  # that captures all the inputs required to create the dataset. Since some of
-  # the inputs to the GeneratorDataset are strings that can't be placed on a
-  # GPU, this fails for the GPU case. Therefore, we disable one_shot_iterator
-  # for GPU targets.
- def make_one_shot_iterator(self):
- if self._is_gpu_target:
- raise ValueError("Cannot create a one shot iterator when using "
- "`tf.contrib.data.copy_to_device()` on GPU. Please use "
- "`Dataset.make_initializable_iterator()` instead.")
- else:
- return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
-
- def _as_variant_tensor(self):
- with ops.device(self._target_device):
- return gen_dataset_ops.generator_dataset(
- self._init_captured_args,
- self._next_captured_args,
- self._finalize_captured_args,
- init_func=self._init_func,
- next_func=self._next_func,
- finalize_func=self._finalize_func,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
+ return prefetching_ops.copy_to_device(target_device, source_device)
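
Both device transformations above now forward to `tf.data.experimental`. A short sketch, assuming a machine with a single GPU; the device strings and buffer size are illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1000).batch(10)

# Keep a couple of batches staged on the accelerator. As the removed
# implementation notes, prefetch_to_device must be the last transformation
# in the pipeline.
on_gpu = dataset.apply(
    tf.data.experimental.prefetch_to_device("/gpu:0", buffer_size=2))

# copy_to_device only moves elements; it is typically followed by a prefetch
# on the target device.
copied = dataset.apply(
    tf.data.experimental.copy_to_device("/gpu:0", source_device="/cpu:0"))
copied = copied.prefetch(1)
```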
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index 344a0763c8..2c95125636 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -17,36 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.util import deprecation
-class RandomDataset(dataset_ops.DatasetSource):
+class RandomDataset(random_ops.RandomDataset):
"""A `Dataset` of pseudorandom values."""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.RandomDataset(...)`.")
def __init__(self, seed=None):
- """A `Dataset` of pseudorandom values."""
- super(RandomDataset, self).__init__()
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.random_dataset(
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return ops.Tensor
-
- @property
- def output_shapes(self):
- return tensor_shape.scalar()
-
- @property
- def output_types(self):
- return dtypes.int64
+ super(RandomDataset, self).__init__(seed)
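
This hunk shows the shim pattern used throughout the change: the contrib class becomes a thin subclass of the core implementation whose constructor carries the deprecation notice. A self-contained sketch of that pattern, using hypothetical class names rather than any real tf.data class:

```python
from tensorflow.python.util import deprecation


class _CoreWidgetDataset(object):
  """Stand-in for a canonical core implementation (hypothetical)."""

  def __init__(self, seed=None):
    self._seed = seed


class WidgetDataset(_CoreWidgetDataset):
  """Deprecated contrib-style alias that forwards to the core class."""

  @deprecation.deprecated(None, "Use the core `WidgetDataset(...)` instead.")
  def __init__(self, seed=None):
    # Keep the public signature; just delegate to the real implementation.
    super(WidgetDataset, self).__init__(seed)
```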
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 360971e200..4601376dff 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -17,295 +17,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-import csv
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import parsing_ops
-from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.lib.io import file_io
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
-from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
-_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
- dtypes.int64, dtypes.string)
-
-
-def _is_valid_int32(str_val):
- try:
- # Checks equality to prevent int32 overflow
- return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
- str_val)
- except (ValueError, OverflowError):
- return False
-
-
-def _is_valid_int64(str_val):
- try:
- dtypes.int64.as_numpy_dtype(str_val)
- return True
- except (ValueError, OverflowError):
- return False
-
-
-def _is_valid_float(str_val, float_dtype):
- try:
- return float_dtype.as_numpy_dtype(str_val) < np.inf
- except ValueError:
- return False
-
-
-def _infer_type(str_val, na_value, prev_type):
- """Given a string, infers its tensor type.
-
- Infers the type of a value by picking the least 'permissive' type possible,
- while still allowing the previous type inference for this column to be valid.
-
- Args:
- str_val: String value to infer the type of.
- na_value: Additional string to recognize as a NA/NaN CSV value.
- prev_type: Type previously inferred based on values of this column that
- we've seen up till now.
- Returns:
- Inferred dtype.
- """
- if str_val in ("", na_value):
- # If the field is null, it gives no extra information about its type
- return prev_type
-
- type_list = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
- ] # list of types to try, ordered from least permissive to most
-
- type_functions = [
- _is_valid_int32,
- _is_valid_int64,
- lambda str_val: _is_valid_float(str_val, dtypes.float32),
- lambda str_val: _is_valid_float(str_val, dtypes.float64),
- lambda str_val: True,
- ] # Corresponding list of validation functions
-
- for i in range(len(type_list)):
- validation_fn = type_functions[i]
- if validation_fn(str_val) and (prev_type is None or
- prev_type in type_list[:i + 1]):
- return type_list[i]
-
-
-def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
- """Generator that yields rows of CSV file(s) in order."""
- for fn in filenames:
- with file_io.FileIO(fn, "r") as f:
- rdr = csv.reader(
- f,
- delimiter=field_delim,
- quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
- if header:
- next(rdr) # Skip header lines
-
- for csv_row in rdr:
- if len(csv_row) != num_cols:
- raise ValueError(
- "Problem inferring types: CSV row has different number of fields "
- "than expected.")
- yield csv_row
-
-
-def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
- na_value, header, num_rows_for_inference,
- select_columns):
- """Infers column types from the first N valid CSV records of files."""
- if select_columns is None:
- select_columns = range(num_cols)
- inferred_types = [None] * len(select_columns)
-
- for i, csv_row in enumerate(
- _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
- if num_rows_for_inference is not None and i >= num_rows_for_inference:
- break
-
- for j, col_index in enumerate(select_columns):
- inferred_types[j] = _infer_type(csv_row[col_index], na_value,
- inferred_types[j])
-
- # Replace None's with a default type
- inferred_types = [t or dtypes.string for t in inferred_types]
- # Default to 0 or '' for null values
- return [
- constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
- for t in inferred_types
- ]
-
-
-def _infer_column_names(filenames, field_delim, use_quote_delim):
- """Infers column names from first rows of files."""
- csv_kwargs = {
- "delimiter": field_delim,
- "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
- }
- with file_io.FileIO(filenames[0], "r") as f:
- try:
- column_names = next(csv.reader(f, **csv_kwargs))
- except StopIteration:
- raise ValueError(("Received StopIteration when reading the header line "
- "of %s. Empty file?") % filenames[0])
-
- for name in filenames[1:]:
- with file_io.FileIO(name, "r") as f:
- try:
- if next(csv.reader(f, **csv_kwargs)) != column_names:
- raise ValueError(
- "Files have different column names in the header row.")
- except StopIteration:
-        raise ValueError(("Received StopIteration when reading the header line "
-                          "of %s. Empty file?") % name)
- return column_names
-
-
-def _get_sorted_col_indices(select_columns, column_names):
- """Transforms select_columns argument into sorted column indices."""
- names_to_indices = {n: i for i, n in enumerate(column_names)}
- num_cols = len(column_names)
- for i, v in enumerate(select_columns):
- if isinstance(v, int):
- if v < 0 or v >= num_cols:
- raise ValueError(
- "Column index %d specified in select_columns out of valid range." %
- v)
- continue
- if v not in names_to_indices:
- raise ValueError(
- "Value '%s' specified in select_columns not a valid column index or "
- "name." % v)
- select_columns[i] = names_to_indices[v]
-
- # Sort and ensure there are no duplicates
- result = sorted(set(select_columns))
- if len(result) != len(select_columns):
- raise ValueError("select_columns contains duplicate columns")
- return result
-
-
-def _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
- """Optionally shuffle and repeat dataset, as requested."""
- if num_epochs != 1 and shuffle:
- # Use shuffle_and_repeat for perf
- return dataset.apply(
- shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
- shuffle_seed))
- elif shuffle:
- return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
- elif num_epochs != 1:
- return dataset.repeat(num_epochs)
- return dataset
-
-
-def make_tf_record_dataset(file_pattern,
- batch_size,
- parser_fn=None,
- num_epochs=None,
- shuffle=True,
- shuffle_buffer_size=None,
- shuffle_seed=None,
- prefetch_buffer_size=optimization.AUTOTUNE,
- num_parallel_reads=None,
- num_parallel_parser_calls=None,
- drop_final_batch=False):
- """Reads and optionally parses TFRecord files into a dataset.
-
- Provides common functionality such as batching, optional parsing, shuffling,
- and performant defaults.
-
- Args:
- file_pattern: List of files or patterns of TFRecord file paths.
- See `tf.gfile.Glob` for pattern rules.
- batch_size: An int representing the number of records to combine
- in a single batch.
- parser_fn: (Optional.) A function accepting string input to parse
- and process the record contents. This function must map records
- to components of a fixed shape, so they may be batched. By
- default, uses the record contents unmodified.
- num_epochs: (Optional.) An int specifying the number of times this
- dataset is repeated. If None (the default), cycles through the
- dataset forever.
- shuffle: (Optional.) A bool that indicates whether the input
- should be shuffled. Defaults to `True`.
- shuffle_buffer_size: (Optional.) Buffer size to use for
- shuffling. A large buffer size ensures better shuffling, but
- increases memory usage and startup time.
- shuffle_seed: (Optional.) Randomization seed to use for shuffling.
- prefetch_buffer_size: (Optional.) An int specifying the number of
- feature batches to prefetch for performance improvement.
- Defaults to auto-tune. Set to 0 to disable prefetching.
- num_parallel_reads: (Optional.) Number of threads used to read
- records from files. By default or if set to a value >1, the
- results will be interleaved.
-    num_parallel_parser_calls: (Optional.) Number of records to parse in
-      parallel. Defaults to an automatic selection.
- drop_final_batch: (Optional.) Whether the last batch should be
- dropped in case its size is smaller than `batch_size`; the
- default behavior is not to drop the smaller batch.
-
- Returns:
- A dataset, where each element matches the output of `parser_fn`
- except it will have an additional leading `batch-size` dimension,
- or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
- unspecified.
- """
- files = dataset_ops.Dataset.list_files(
- file_pattern, shuffle=shuffle, seed=shuffle_seed)
-
- if num_parallel_reads is None:
- # Note: We considered auto-tuning this value, but there is a concern
- # that this affects the mixing of records from different files, which
- # could affect training convergence/accuracy, so we are defaulting to
- # a constant for now.
- num_parallel_reads = 24
- dataset = core_readers.TFRecordDataset(
- files, num_parallel_reads=num_parallel_reads)
-
- if shuffle_buffer_size is None:
- # TODO(josh11b): Auto-tune this value when not specified
- shuffle_buffer_size = 10000
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- drop_final_batch = drop_final_batch or num_epochs is None
-
- if parser_fn is None:
- dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
- else:
- # TODO(josh11b): if num_parallel_parser_calls is None, use some function
- # of num cores instead of map_and_batch's default behavior of one batch.
- dataset = dataset.apply(batching.map_and_batch(
- parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
- drop_remainder=drop_final_batch))
-
- if prefetch_buffer_size == 0:
- return dataset
- else:
- return dataset.prefetch(buffer_size=prefetch_buffer_size)
-
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.make_csv_dataset(...)`.")
def make_csv_dataset(
file_pattern,
batch_size,
@@ -387,7 +112,6 @@ def make_csv_dataset(
prefetch_buffer_size: An int specifying the number of feature
batches to prefetch for performance improvement. Recommended value is the
number of batches consumed per training step. Defaults to auto-tune.
-
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
sloppy: If `True`, reading performance will be improved at
@@ -411,106 +135,18 @@ def make_csv_dataset(
Raises:
ValueError: If any of the arguments is malformed.
"""
- # Create dataset of all matching filenames
- filenames = _get_file_names(file_pattern, False)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
- if shuffle:
- dataset = dataset.shuffle(len(filenames), shuffle_seed)
-
- # Clean arguments; figure out column names and defaults
+ return readers.make_csv_dataset(
+ file_pattern, batch_size, column_names, column_defaults, label_name,
+ select_columns, field_delim, use_quote_delim, na_value, header,
+ num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
+ prefetch_buffer_size, num_parallel_reads, sloppy, num_rows_for_inference,
+ compression_type)
- if column_names is None:
- if not header:
- raise ValueError("Cannot infer column names without a header line.")
- # If column names are not provided, infer from the header lines
- column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
- if len(column_names) != len(set(column_names)):
- raise ValueError("Cannot have duplicate column names.")
- if select_columns is not None:
- select_columns = _get_sorted_col_indices(select_columns, column_names)
-
- if column_defaults is not None:
- column_defaults = [
- constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
- for x in column_defaults
- ]
- else:
- # If column defaults are not provided, infer from records at graph
- # construction time
- column_defaults = _infer_column_defaults(
- filenames, len(column_names), field_delim, use_quote_delim, na_value,
- header, num_rows_for_inference, select_columns)
-
- if select_columns is not None and len(column_defaults) != len(select_columns):
- raise ValueError(
- "If specified, column_defaults and select_columns must have same "
- "length."
- )
- if select_columns is not None and len(column_names) > len(select_columns):
- # Pick the relevant subset of column names
- column_names = [column_names[i] for i in select_columns]
-
- if label_name is not None and label_name not in column_names:
- raise ValueError("`label_name` provided must be one of the columns.")
-
- def filename_to_dataset(filename):
- return CsvDataset(
- filename,
- record_defaults=column_defaults,
- field_delim=field_delim,
- use_quote_delim=use_quote_delim,
- na_value=na_value,
- select_cols=select_columns,
- header=header,
- compression_type=compression_type,
- )
-
- def map_fn(*columns):
- """Organizes columns into a features dictionary.
-
- Args:
- *columns: list of `Tensor`s corresponding to one csv record.
- Returns:
- An OrderedDict of feature names to values for that particular record. If
- label_name is provided, extracts the label feature to be returned as the
- second element of the tuple.
- """
- features = collections.OrderedDict(zip(column_names, columns))
- if label_name is not None:
- label = features.pop(label_name)
- return features, label
- return features
-
- # Read files sequentially (if num_parallel_reads=1) or in parallel
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
-
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # Apply batch before map for perf, because map has high overhead relative
- # to the size of the computation in each map.
- # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- dataset = dataset.batch(batch_size=batch_size,
- drop_remainder=num_epochs is None)
- dataset = dataset_ops.MapDataset(
- dataset, map_fn, use_inter_op_parallelism=False)
- dataset = dataset.prefetch(prefetch_buffer_size)
-
- return dataset
-
-
-_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
-
-
-class CsvDataset(dataset_ops.DatasetSource):
+class CsvDataset(readers.CsvDataset):
"""A Dataset comprising lines from one or more CSV files."""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.CsvDataset(...)`.")
def __init__(self,
filenames,
record_defaults,
@@ -521,140 +157,13 @@ class CsvDataset(dataset_ops.DatasetSource):
use_quote_delim=True,
na_value="",
select_cols=None):
- """Creates a `CsvDataset` by reading and decoding CSV files.
-
- The elements of this dataset correspond to records from the file(s).
- RFC 4180 format is expected for CSV files
- (https://tools.ietf.org/html/rfc4180)
-    Note that we allow leading and trailing spaces in int or float fields.
-
-
- For example, suppose we have a file 'my_file0.csv' with four CSV columns of
- different data types:
- ```
- abcdefg,4.28E10,5.55E6,12
- hijklmn,-5.3E14,,2
- ```
-
- We can construct a CsvDataset from it as follows:
- ```python
- dataset = tf.contrib.data.CsvDataset(
- "my_file*.csv",
- [tf.float32, # Required field, use dtype or empty tensor
- tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
- tf.int32, # Required field, use dtype or empty tensor
- ],
- select_cols=[1,2,3] # Only parse last three columns
- )
- ```
-
- The expected output of its iterations is:
- ```python
- next_element = dataset.make_one_shot_iterator().get_next()
- with tf.Session() as sess:
- while True:
- try:
- print(sess.run(next_element))
- except tf.errors.OutOfRangeError:
- break
-
- >> (4.28e10, 5.55e6, 12)
- >> (-5.3e14, 0.0, 2)
- ```
-
- Args:
- filenames: A `tf.string` tensor containing one or more filenames.
- record_defaults: A list of default values for the CSV fields. Each item in
- the list is either a valid CSV `DType` (float32, float64, int32, int64,
- string), or a `Tensor` object with one of the above types. One per
- column of CSV data, with either a scalar `Tensor` default value for the
- column if it is optional, or `DType` or empty `Tensor` if required. If
- both this and `select_columns` are specified, these must have the same
- lengths, and `column_defaults` is assumed to be sorted in order of
- increasing column index.
- compression_type: (Optional.) A `tf.string` scalar evaluating to one of
- `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
- compression.
- buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
- to buffer while reading files. Defaults to 4MB.
- header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
- have header line(s) that should be skipped when parsing. Defaults to
- `False`.
- field_delim: (Optional.) A `tf.string` scalar containing the delimiter
- character that separates fields in a record. Defaults to `","`.
- use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
- double quotation marks as regular characters inside of string fields
- (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
- na_value: (Optional.) A `tf.string` scalar indicating a value that will
- be treated as NA/NaN.
- select_cols: (Optional.) A sorted list of column indices to select from
- the input data. If specified, only this subset of columns will be
- parsed. Defaults to parsing all columns.
- """
- super(CsvDataset, self).__init__()
- self._filenames = ops.convert_to_tensor(
- filenames, dtype=dtypes.string, name="filenames")
- self._compression_type = convert.optional_param_to_tensor(
- "compression_type",
- compression_type,
- argument_default="",
- argument_dtype=dtypes.string)
- record_defaults = [
- constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
- for x in record_defaults
- ]
- self._record_defaults = ops.convert_n_to_tensor(
- record_defaults, name="record_defaults")
- self._buffer_size = convert.optional_param_to_tensor(
- "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
- self._header = ops.convert_to_tensor(
- header, dtype=dtypes.bool, name="header")
- self._field_delim = ops.convert_to_tensor(
- field_delim, dtype=dtypes.string, name="field_delim")
- self._use_quote_delim = ops.convert_to_tensor(
- use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
- self._na_value = ops.convert_to_tensor(
- na_value, dtype=dtypes.string, name="na_value")
- self._select_cols = convert.optional_param_to_tensor(
- "select_cols",
- select_cols,
- argument_default=[],
- argument_dtype=dtypes.int64,
- )
- self._output_shapes = tuple(
- tensor_shape.scalar() for _ in range(len(record_defaults)))
- self._output_types = tuple(d.dtype for d in self._record_defaults)
- self._output_classes = tuple(
- ops.Tensor for _ in range(len(record_defaults)))
-
- def _as_variant_tensor(self):
- # Constructs graph node for the dataset op.
- return gen_experimental_dataset_ops.experimental_csv_dataset(
- filenames=self._filenames,
- record_defaults=self._record_defaults,
- buffer_size=self._buffer_size,
- header=self._header,
- output_shapes=self._output_shapes,
- field_delim=self._field_delim,
- use_quote_delim=self._use_quote_delim,
- na_value=self._na_value,
- select_cols=self._select_cols,
- compression_type=self._compression_type,
- )
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_classes(self):
- return self._output_classes
+ super(CsvDataset, self).__init__(
+ filenames, record_defaults, compression_type, buffer_size, header,
+ field_delim, use_quote_delim, na_value, select_cols)
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_batched_features_dataset(...)`.")
def make_batched_features_dataset(file_pattern,
batch_size,
features,
@@ -759,57 +268,15 @@ def make_batched_features_dataset(file_pattern,
Raises:
ValueError: If `label_key` is not one of the `features` keys.
"""
- # Create dataset of all matching filenames
- filenames = _get_file_names(file_pattern, False)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
- if shuffle:
- dataset = dataset.shuffle(len(filenames), shuffle_seed)
-
- # Read `Example` records from files as tensor objects.
- if reader_args is None:
- reader_args = []
+ return readers.make_batched_features_dataset(
+ file_pattern, batch_size, features, reader, label_key, reader_args,
+ num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
+ prefetch_buffer_size, reader_num_threads, parser_num_threads,
+ sloppy_ordering, drop_final_batch)
- # Read files sequentially (if reader_num_threads=1) or in parallel
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- lambda filename: reader(filename, *reader_args),
- cycle_length=reader_num_threads,
- sloppy=sloppy_ordering))
- # Extract values if the `Example` tensors are stored as key-value tuples.
- if dataset.output_types == (dtypes.string, dtypes.string):
- dataset = dataset_ops.MapDataset(
- dataset, lambda _, v: v, use_inter_op_parallelism=False)
-
- # Apply dataset repeat and shuffle transformations.
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- dataset = dataset.batch(
- batch_size, drop_remainder=drop_final_batch or num_epochs is None)
-
- # Parse `Example` tensors to a dictionary of `Feature` tensors.
- dataset = dataset.apply(
- parsing_ops.parse_example_dataset(
- features, num_parallel_calls=parser_num_threads))
-
- if label_key:
- if label_key not in features:
- raise ValueError(
- "The `label_key` provided (%r) must be one of the `features` keys." %
- label_key)
- dataset = dataset.map(lambda x: (x, x.pop(label_key)))
-
- dataset = dataset.prefetch(prefetch_buffer_size)
- return dataset
-
-
-@deprecation.deprecated(None,
- "Use `tf.contrib.data.make_batched_features_dataset`")
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_batched_features_dataset(...)`")
def read_batch_features(file_pattern,
batch_size,
features,
@@ -879,7 +346,7 @@ def read_batch_features(file_pattern,
Returns:
A dict from keys in features to `Tensor` or `SparseTensor` objects.
"""
- dataset = make_batched_features_dataset(
+ dataset = readers.make_batched_features_dataset(
file_pattern,
batch_size,
features,
@@ -893,96 +360,13 @@ def read_batch_features(file_pattern,
return outputs
-def _get_file_names(file_pattern, shuffle):
- """Parse list of file names from pattern, optionally shuffled.
-
- Args:
- file_pattern: File glob pattern, or list of glob patterns.
- shuffle: Whether to shuffle the order of file names.
-
- Returns:
- List of file names matching `file_pattern`.
-
- Raises:
- ValueError: If `file_pattern` is empty, or pattern matches no files.
- """
- if isinstance(file_pattern, list):
- if not file_pattern:
- raise ValueError("File pattern is empty.")
- file_names = []
- for entry in file_pattern:
- file_names.extend(gfile.Glob(entry))
- else:
- file_names = list(gfile.Glob(file_pattern))
-
- if not file_names:
- raise ValueError("No files match %s." % file_pattern)
-
- # Sort files so it will be deterministic for unit tests.
- if not shuffle:
- file_names = sorted(file_names)
- return file_names
-
-
-class SqlDataset(dataset_ops.DatasetSource):
+class SqlDataset(readers.SqlDataset):
"""A `Dataset` consisting of the results from a SQL query."""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.SqlDataset(...)`.")
def __init__(self, driver_name, data_source_name, query, output_types):
- """Creates a `SqlDataset`.
-
- `SqlDataset` allows a user to read data from the result set of a SQL query.
- For example:
-
- ```python
- dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
- "SELECT name, age FROM people",
- (tf.string, tf.int32))
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
- # Prints the rows of the result set of the above query.
- while True:
- try:
- print(sess.run(next_element))
- except tf.errors.OutOfRangeError:
- break
- ```
-
- Args:
- driver_name: A 0-D `tf.string` tensor containing the database type.
- Currently, the only supported value is 'sqlite'.
- data_source_name: A 0-D `tf.string` tensor containing a connection string
- to connect to the database.
- query: A 0-D `tf.string` tensor containing the SQL query to execute.
- output_types: A tuple of `tf.DType` objects representing the types of the
- columns returned by `query`.
- """
- super(SqlDataset, self).__init__()
- self._driver_name = ops.convert_to_tensor(
- driver_name, dtype=dtypes.string, name="driver_name")
- self._data_source_name = ops.convert_to_tensor(
- data_source_name, dtype=dtypes.string, name="data_source_name")
- self._query = ops.convert_to_tensor(
- query, dtype=dtypes.string, name="query")
- self._output_types = output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.sql_dataset(self._driver_name,
- self._data_source_name, self._query,
- nest.flatten(self.output_types),
- nest.flatten(self.output_shapes))
-
- @property
- def output_classes(self):
- return nest.map_structure(lambda _: ops.Tensor, self._output_types)
-
- @property
- def output_shapes(self):
- return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
- self._output_types)
-
- @property
- def output_types(self):
- return self._output_types
+ super(SqlDataset, self).__init__(
+ driver_name, data_source_name, query, output_types)
class LMDBDataset(dataset_ops.DatasetSource):
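
The reader helpers above now delegate to `tf.data.experimental` (with `LMDBDataset` left behind in contrib, per the commit message). A minimal sketch of the preferred CSV endpoint; the file pattern and the `label` column are assumptions for illustration.

```python
import tensorflow as tf

# Assumes CSV files with a header row and a "label" column.
dataset = tf.data.experimental.make_csv_dataset(
    file_pattern="data/*.csv",
    batch_size=32,
    label_name="label",
    num_epochs=1,
    shuffle=True)

# Each element is a (features_dict, labels) pair; in TF 1.x graph mode it is
# typically consumed through an iterator.
features, labels = dataset.make_one_shot_iterator().get_next()
```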
diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py
index 75642f143e..29d77528d9 100644
--- a/tensorflow/contrib/data/python/ops/resampling.py
+++ b/tensorflow/contrib/data/python/ops/resampling.py
@@ -17,22 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.contrib.data.python.ops import scan_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.data.experimental.ops import resampling
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.rejection_resample(...)`.")
def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
"""A transformation that resamples a dataset to achieve a target distribution.
@@ -52,243 +42,5 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
- class_values_ds = dataset.map(class_func)
-
- # Get initial distribution.
- if initial_dist is not None:
- initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
- acceptance_dist, prob_of_original = (
- _calculate_acceptance_probs_with_mixing(initial_dist_t,
- target_dist_t))
- initial_dist_ds = dataset_ops.Dataset.from_tensors(
- initial_dist_t).repeat()
- acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
- acceptance_dist).repeat()
- prob_of_original_ds = dataset_ops.Dataset.from_tensors(
- prob_of_original).repeat()
- else:
- initial_dist_ds = _estimate_initial_dist_ds(
- target_dist_t, class_values_ds)
- acceptance_and_original_prob_ds = initial_dist_ds.map(
- lambda initial: _calculate_acceptance_probs_with_mixing(
- initial, target_dist_t))
- acceptance_dist_ds = acceptance_and_original_prob_ds.map(
- lambda accept_prob, _: accept_prob)
- prob_of_original_ds = acceptance_and_original_prob_ds.map(
- lambda _, prob_original: prob_original)
- filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
- class_values_ds, seed)
- # Prefetch filtered dataset for speed.
- filtered_ds = filtered_ds.prefetch(3)
-
- prob_original_static = _get_prob_original_static(
- initial_dist_t, target_dist_t) if initial_dist is not None else None
- if prob_original_static == 1:
- return dataset_ops.Dataset.zip((class_values_ds, dataset))
- elif prob_original_static == 0:
- return filtered_ds
- else:
- return interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
- weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
- seed=seed)
-
- return _apply_fn
-
-
-def _get_prob_original_static(initial_dist_t, target_dist_t):
- """Returns the static probability of sampling from the original.
-
- `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
- an Op that it isn't defined for. We have some custom logic to avoid this.
-
- Args:
- initial_dist_t: A tensor of the initial distribution.
- target_dist_t: A tensor of the target distribution.
-
- Returns:
- The probability of sampling from the original distribution as a constant,
- if it is a constant, or `None`.
- """
- init_static = tensor_util.constant_value(initial_dist_t)
- target_static = tensor_util.constant_value(target_dist_t)
-
- if init_static is None or target_static is None:
- return None
- else:
- return np.min(target_static / init_static)
-
-
-def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
- seed):
- """Filters a dataset based on per-class acceptance probabilities.
-
- Args:
- dataset: The dataset to be filtered.
- acceptance_dist_ds: A dataset of acceptance probabilities.
- initial_dist_ds: A dataset of the initial probability distribution, given or
- estimated.
- class_values_ds: A dataset of the corresponding classes.
- seed: (Optional.) Python integer seed for the resampler.
-
- Returns:
- A dataset of (class value, data) after filtering.
- """
- def maybe_warn_on_large_rejection(accept_dist, initial_dist):
- proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
- return control_flow_ops.cond(
- math_ops.less(proportion_rejected, .5),
- lambda: accept_dist,
- lambda: logging_ops.Print( # pylint: disable=g-long-lambda
- accept_dist, [proportion_rejected, initial_dist, accept_dist],
- message="Proportion of examples rejected by sampler is high: ",
- summarize=100,
- first_n=10))
-
- acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
- initial_dist_ds))
- .map(maybe_warn_on_large_rejection))
-
- def _gather_and_copy(class_val, acceptance_prob, data):
- return class_val, array_ops.gather(acceptance_prob, class_val), data
-
- current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
- (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
- filtered_ds = (
- current_probabilities_and_class_and_data_ds
- .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
- return filtered_ds.map(lambda class_value, _, data: (class_value, data))
-
-
-def _estimate_initial_dist_ds(
- target_dist_t, class_values_ds, dist_estimation_batch_size=32,
- smoothing_constant=10):
- num_classes = (target_dist_t.shape[0].value or
- array_ops.shape(target_dist_t)[0])
- initial_examples_per_class_seen = array_ops.fill(
- [num_classes], np.int64(smoothing_constant))
-
- def update_estimate_and_tile(num_examples_per_class_seen, c):
- updated_examples_per_class_seen, dist = _estimate_data_distribution(
- c, num_examples_per_class_seen)
- tiled_dist = array_ops.tile(
- array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
- return updated_examples_per_class_seen, tiled_dist
-
- initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
- .apply(scan_ops.scan(initial_examples_per_class_seen,
- update_estimate_and_tile))
- .apply(batching.unbatch()))
-
- return initial_dist_ds
-
-
-def _get_target_to_initial_ratio(initial_probs, target_probs):
- # Add tiny to initial_probs to avoid divide by zero.
- denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
- return target_probs / denom
-
-
-def _estimate_data_distribution(c, num_examples_per_class_seen):
- """Estimate data distribution as labels are seen.
-
- Args:
- c: The class labels. Type `int32`, shape `[batch_size]`.
- num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
- containing counts.
-
- Returns:
-    num_examples_per_class_seen: Updated counts. Type `int64`, shape
- `[num_classes]`.
- dist: The updated distribution. Type `float32`, shape `[num_classes]`.
- """
- num_classes = num_examples_per_class_seen.get_shape()[0].value
- # Update the class-count based on what labels are seen in batch.
- num_examples_per_class_seen = math_ops.add(
- num_examples_per_class_seen, math_ops.reduce_sum(
- array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
- init_prob_estimate = math_ops.truediv(
- num_examples_per_class_seen,
- math_ops.reduce_sum(num_examples_per_class_seen))
- dist = math_ops.cast(init_prob_estimate, dtypes.float32)
- return num_examples_per_class_seen, dist
-
-
-def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
- """Calculates the acceptance probabilities and mixing ratio.
-
- In this case, we assume that we can *either* sample from the original data
- distribution with probability `m`, or sample from a reshaped distribution
- that comes from rejection sampling on the original distribution. This
- rejection sampling is done on a per-class basis, with `a_i` representing the
- probability of accepting data from class `i`.
-
- This method is based on solving the following analysis for the reshaped
- distribution:
-
- Let F be the probability of a rejection (on any example).
- Let p_i be the proportion of examples in the data in class i (init_probs)
-  Let a_i be the rate the rejection sampler should *accept* class i
-  Let t_i be the target proportion in the minibatches for class i (target_probs)
-
- ```
- F = sum_i(p_i * (1-a_i))
- = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
- ```
-
- An example with class `i` will be accepted if `k` rejections occur, then an
- example with class `i` is seen by the rejector, and it is accepted. This can
- be written as follows:
-
- ```
-  t_i = sum_k=0^inf(F^k * p_i * a_i)
-      = p_i * a_i / (1 - F)    using geometric series identity, since 0 <= F < 1
-      = p_i * a_i / sum_j(p_j * a_j)    using F from above
- ```
-
- Note that the following constraints hold:
- ```
- 0 <= p_i <= 1, sum_i(p_i) = 1
- 0 <= a_i <= 1
- 0 <= t_i <= 1, sum_i(t_i) = 1
- ```
-
- A solution for a_i in terms of the other variables is the following:
- ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
-
- If we try to minimize the amount of data rejected, we get the following:
-
- M_max = max_i [ t_i / p_i ]
- M_min = min_i [ t_i / p_i ]
-
- The desired probability of accepting data if it comes from class `i`:
-
- a_i = (t_i/p_i - m) / (M_max - m)
-
- The desired probability of pulling a data element from the original dataset,
- rather than the filtered one:
-
- m = M_min
-
- Args:
- initial_probs: A Tensor of the initial probability distribution, given or
- estimated.
-    target_probs: A Tensor of the target probability distribution.
-
- Returns:
-    (A 1D Tensor with the per-class acceptance probabilities, the desired
-    probability of pulling from the original distribution.)
- """
- ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
- max_ratio = math_ops.reduce_max(ratio_l)
- min_ratio = math_ops.reduce_min(ratio_l)
-
- # Target prob to sample from original distribution.
- m = min_ratio
-
- # TODO(joelshor): Simplify fraction, if possible.
- a_i = (ratio_l - m) / (max_ratio - m)
- return a_i, m
+ return resampling.rejection_resample(class_func, target_dist, initial_dist,
+ seed)
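
The acceptance-probability derivation in the removed docstring can be checked numerically. The sketch below is a standalone NumPy restatement of that formula, not the TensorFlow graph implementation; the 80/20 example distribution is hypothetical.

```python
import numpy as np


def acceptance_probs_with_mixing(initial_probs, target_probs):
  """a_i = (t_i/p_i - m) / (M_max - m) with m = M_min, per the docstring above."""
  p = np.asarray(initial_probs, dtype=np.float64)
  t = np.asarray(target_probs, dtype=np.float64)
  ratio = t / (p + np.finfo(np.float64).tiny)  # t_i / p_i, guarding against 0
  m = ratio.min()            # probability of sampling from the original data
  # Assumes target != initial, so ratio.max() > m and the division is defined.
  accept = (ratio - m) / (ratio.max() - m)
  return accept, m


# Example: the data is 80/20 but we want a 50/50 class mix.
accept, m = acceptance_probs_with_mixing([0.8, 0.2], [0.5, 0.5])
print(accept)  # [0. 1.]  -> the resampled branch only keeps class-1 examples
print(m)       # 0.625    -> 62.5% of the time we draw from the original stream
```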
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index c52582cd35..0ca9fddb23 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -17,137 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ScanDataset(dataset_ops.UnaryDataset):
- """A dataset that scans a function across its input."""
-
- def __init__(self, input_dataset, initial_state, scan_func):
- """See `scan()` for details."""
- super(_ScanDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- with ops.name_scope("initial_state"):
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- self._initial_state = nest.pack_sequence_as(initial_state, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
- t, name="component_%d" % i)
- for i, t in enumerate(nest.flatten(initial_state))
- ])
-
- # Compute initial values for the state classes, shapes and types based on
- # the initial state. The shapes may be refined by running `tf_scan_func` one
- # or more times below.
- self._state_classes = sparse.get_classes(self._initial_state)
- self._state_shapes = nest.pack_sequence_as(
- self._initial_state,
- [t.get_shape() for t in nest.flatten(self._initial_state)])
- self._state_types = nest.pack_sequence_as(
- self._initial_state,
- [t.dtype for t in nest.flatten(self._initial_state)])
-
- # Will be populated by calling `tf_scan_func`.
- self._output_classes = None
- self._output_shapes = None
- self._output_types = None
-
- # Iteratively rerun the scan function until reaching a fixed point on
- # `self._state_shapes`.
- need_to_rerun = True
- while need_to_rerun:
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- scan_func, "tf.contrib.data.scan()",
- input_classes=(self._state_classes, input_dataset.output_classes),
- input_shapes=(self._state_shapes, input_dataset.output_shapes),
- input_types=(self._state_types, input_dataset.output_types),
- add_to_graph=False)
- if not (
- isinstance(wrapped_func.output_types, collections.Sequence) and
- len(wrapped_func.output_types) == 2):
- raise TypeError("The scan function must return a pair comprising the "
- "new state and the output value.")
-
- new_state_classes, self._output_classes = wrapped_func.output_classes
-
- # Extract and validate class information from the returned values.
- for new_state_class, state_class in zip(
- nest.flatten(new_state_classes),
- nest.flatten(self._state_classes)):
- if not issubclass(new_state_class, state_class):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes, new_state_classes))
-
- # Extract and validate type information from the returned values.
- new_state_types, self._output_types = wrapped_func.output_types
- for new_state_type, state_type in zip(
- nest.flatten(new_state_types), nest.flatten(self._state_types)):
- if new_state_type != state_type:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types, new_state_types))
-
- # Extract shape information from the returned values.
- new_state_shapes, self._output_shapes = wrapped_func.output_shapes
-
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_new_state_shapes = nest.flatten(new_state_shapes)
- weakened_state_shapes = [
- original.most_specific_compatible_shape(new)
- for original, new in zip(flat_state_shapes, flat_new_state_shapes)
- ]
-
- need_to_rerun = False
- for original_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if original_shape.ndims is not None and (
- weakened_shape.ndims is None or
- original_shape.as_list() != weakened_shape.as_list()):
- need_to_rerun = True
- break
-
- if need_to_rerun:
- self._state_shapes = nest.pack_sequence_as(self._state_shapes,
- weakened_state_shapes)
-
- self._scan_func = wrapped_func.function
- self._scan_func.add_to_graph(ops.get_default_graph())
-
- def _as_variant_tensor(self):
- input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
- return gen_dataset_ops.scan_dataset(
- input_t,
- nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
- self._scan_func.captured_inputs,
- f=self._scan_func,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.scan(...)`.")
def scan(initial_state, scan_func):
"""A transformation that scans a function across an input dataset.
@@ -168,7 +42,4 @@ def scan(initial_state, scan_func):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return _ScanDataset(dataset, initial_state, scan_func)
-
- return _apply_fn
+ return scan_ops.scan(initial_state, scan_func)
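
For reference, a minimal usage sketch of the endpoint the deprecated wrapper now forwards to, assuming a TensorFlow build where `tf.data.experimental.scan` is already exported:

```python
import tensorflow as tf

# Cumulative sum: scan_func maps (state, element) -> (new_state, output_element).
dataset = tf.data.Dataset.range(5)
dataset = dataset.apply(
    tf.data.experimental.scan(tf.constant(0, dtype=tf.int64),
                              lambda state, x: (state + x, state + x)))
# Yields 0, 1, 3, 6, 10.
```
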
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 985d1d87d0..329b34fdfe 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -17,54 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that fuses `shuffle` and `repeat`."""
-
- def __init__(self, input_dataset, buffer_size, count=None, seed=None):
- super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._buffer_size = ops.convert_to_tensor(
- buffer_size, dtype=dtypes.int64, name="buffer_size")
- if count is None:
- self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
- else:
- self._count = ops.convert_to_tensor(
- count, dtype=dtypes.int64, name="count")
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.shuffle_and_repeat_dataset(
- input_resource,
- buffer_size=self._buffer_size,
- count=self._count,
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.shuffle_and_repeat(...)`.")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
"""Shuffles and repeats a Dataset returning a new permutation for each epoch.
@@ -93,8 +51,4 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset): # pylint: disable=missing-docstring
- return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
-
- return _apply_fn
+ return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed)
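
A short usage sketch of the fused transformation via its new endpoint; the buffer size, repeat count, and seed below are illustrative:

```python
import tensorflow as tf

# Fused shuffle+repeat: 3 epochs over the data, reshuffled with a new
# permutation each epoch; buffer_size covers the whole dataset here.
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=10, count=3, seed=42))
```
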
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
deleted file mode 100644
index bc47c5989d..0000000000
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Experimental API for gathering statistics from `tf.data` pipelines."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class StatsAggregator(object):
- """A stateful resource that aggregates statistics from one or more iterators.
-
- To record statistics, use one of the custom transformation functions defined
- in this module when defining your `tf.data.Dataset`. All statistics will be
- aggregated by the `StatsAggregator` that is associated with a particular
- iterator (see below). For example, to record the latency of producing each
- element by iterating over a dataset:
-
- ```python
- dataset = ...
- dataset = dataset.apply(stats_ops.latency_stats("record_latency"))
- ```
-
- To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
- the following pattern:
-
- ```python
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = ...
-
- # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
- dataset = dataset.apply(
- tf.contrib.data.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_one_shot_iterator()
- ```
-
- To get a protocol buffer summary of the currently aggregated statistics,
- use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
- is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
- so that the summaries will be included with any existing summaries.
-
- ```python
- stats_aggregator = stats_ops.StatsAggregator()
- # ...
- stats_summary = stats_aggregator.get_summary()
- tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
- ```
-
- Note: This interface is experimental and expected to change. In particular,
- we expect to add other implementations of `StatsAggregator` that provide
- different ways of exporting statistics, and add more types of statistics.
- """
-
- def __init__(self):
- """Creates a `StatsAggregator`."""
- self._resource = gen_dataset_ops.stats_aggregator_handle()
-
- # TODO(b/116314787): Update this/add support for V2 summary API.
- def get_summary(self):
- """Returns a string `tf.Tensor` that summarizes the aggregated statistics.
-
- The returned tensor will contain a serialized `tf.summary.Summary` protocol
- buffer, which can be used with the standard TensorBoard logging facilities.
-
- Returns:
- A scalar string `tf.Tensor` that summarizes the aggregated statistics.
- """
- return gen_dataset_ops.stats_aggregator_summary(self._resource)
-
-
-class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and sets given stats_aggregator."""
-
- def __init__(self, input_dataset, stats_aggregator):
- super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._stats_aggregator = stats_aggregator
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.set_stats_aggregator_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._stats_aggregator._resource, # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-def set_stats_aggregator(stats_aggregator):
- """Set the given `stats_aggregator` for aggregating the input dataset stats.
-
- Args:
- stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _SetStatsAggregatorDataset(dataset, stats_aggregator)
-
- return _apply_fn
-
-
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def bytes_produced_stats(tag):
- """Records the number of bytes produced by each element of the input dataset.
-
- To consume the statistics, associate a `StatsAggregator` with the output
- dataset.
-
- Args:
- tag: String. All statistics recorded by the returned transformation will
- be associated with the given `tag`.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset,
- tag)
-
- return _apply_fn
-
-
-def latency_stats(tag):
- """Records the latency of producing each element of the input dataset.
-
- To consume the statistics, associate a `StatsAggregator` with the output
- dataset.
-
- Args:
- tag: String. All statistics recorded by the returned transformation will
- be associated with the given `tag`.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag)
-
- return _apply_fn
-
-
-class _StatsDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and also records statistics."""
-
- def __init__(self, input_dataset, op_function, tag):
- super(_StatsDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._op_function = op_function
- self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
-
- def _as_variant_tensor(self):
- return self._op_function(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._tag,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
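
Because the stats endpoints move to `tf.data.experimental` without a `tf.contrib.data` alias, a sketch of the equivalent pipeline wiring under the new names may help; the tag string and summary-collection usage mirror the docstring removed above and are illustrative, not prescriptive:

```python
import tensorflow as tf

aggregator = tf.data.experimental.StatsAggregator()

dataset = tf.data.Dataset.range(1000)
# Record per-element production latency under an illustrative tag.
dataset = dataset.apply(tf.data.experimental.latency_stats("record_latency"))
# Route this dataset's statistics to the aggregator.
dataset = dataset.apply(tf.data.experimental.set_stats_aggregator(aggregator))
iterator = dataset.make_one_shot_iterator()

# Export the aggregated stats as a serialized tf.summary.Summary proto.
summary_op = aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, summary_op)
```
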
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index f73c3fd9cb..20cceb4647 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -17,88 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import threading
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.ops import resource_variable_ops
-
-_uid_counter = 0
-_uid_lock = threading.Lock()
-
-
-def _generate_shared_name(prefix):
- with _uid_lock:
- global _uid_counter
- uid = _uid_counter
- _uid_counter += 1
- return "{}{}".format(prefix, uid)
-
-
-# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-class PrivateThreadPool(object):
- """A stateful resource that represents a private thread pool."""
-
- def __init__(self, num_threads, display_name=None,
- max_intra_op_parallelism=1):
- """Creates a `PrivateThreadPool` with the given number of threads."""
- if context.executing_eagerly():
- shared_name = _generate_shared_name("privatethreadpool")
- self._resource = ged_ops.experimental_thread_pool_handle(
- num_threads=num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name=display_name,
- shared_name=shared_name)
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device=context.context().device_name)
- else:
- self._resource = ged_ops.experimental_thread_pool_handle(
- num_threads=num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name=display_name)
-
-
-class _ThreadPoolDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and sets a custom threadpool."""
-
- def __init__(self, input_dataset, thread_pool):
- super(_ThreadPoolDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._thread_pool = thread_pool
-
- def _as_variant_tensor(self):
- return ged_ops.experimental_thread_pool_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._thread_pool._resource, # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def override_threadpool(dataset, thread_pool):
- """Returns a new dataset that uses the given thread pool for its operations.
-
- Args:
- dataset: A `tf.data.Dataset` object.
- thread_pool: A `PrivateThreadPool` object.
-
- Returns:
- A dataset containing the same values as `dataset`, but which uses
- `thread_pool` to compute any of its parallel operations (such as
- `tf.data.Dataset.map`).
- """
- return _ThreadPoolDataset(dataset, thread_pool)
+# pylint: disable=unused-import
+from tensorflow.python.data.experimental.ops.threadpool import override_threadpool
+from tensorflow.python.data.experimental.ops.threadpool import PrivateThreadPool
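
These symbols are still not publicly exported (see the TODOs above), so a usage sketch has to import the experimental module directly; the thread count and display name below are illustrative:

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import threadpool

dataset = tf.data.Dataset.range(100).map(
    lambda x: x * 2, num_parallel_calls=4)

# Run the dataset's parallel work on a dedicated 8-thread pool instead of the
# session's shared inter-op pool.
pool = threadpool.PrivateThreadPool(8, display_name="input_pipeline_pool")
dataset = threadpool.override_threadpool(dataset, pool)
```
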
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index ed363a7090..909d06c677 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -17,11 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.data.experimental.ops import unique as experimental_unique
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.unique()`.")
def unique():
"""Creates a `Dataset` from another `Dataset`, discarding duplicates.
@@ -39,39 +39,4 @@ def unique():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _UniqueDataset(dataset)
-
- return _apply_fn
-
-
-class _UniqueDataset(dataset_ops.UnaryDataset):
- """A `Dataset` contains the unique elements from its input."""
-
- def __init__(self, input_dataset):
- """See `unique()` for details."""
- super(_UniqueDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
- dtypes.string):
- raise TypeError(
- "`tf.contrib.data.unique()` only supports inputs with a single "
- "`tf.int32`, `tf.int64`, or `tf.string` component.")
-
- def _as_variant_tensor(self):
- return gen_experimental_dataset_ops.experimental_unique_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return experimental_unique.unique()
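
A minimal usage sketch of the new endpoint; the element values are illustrative:

```python
import tensorflow as tf

# Deduplicate a dataset of scalar integer elements.
dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
dataset = dataset.apply(tf.data.experimental.unique())  # yields 1, 37, 2
```
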
diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py
index c455fdcba6..42fb69bf07 100644
--- a/tensorflow/contrib/data/python/ops/writers.py
+++ b/tensorflow/contrib/data/python/ops/writers.py
@@ -17,42 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.data.experimental.ops import writers
+from tensorflow.python.util import deprecation
-class TFRecordWriter(object):
+class TFRecordWriter(writers.TFRecordWriter):
"""Writes data to a TFRecord file."""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.TFRecordWriter(...)`.")
def __init__(self, filename, compression_type=None):
- self._filename = ops.convert_to_tensor(
- filename, dtypes.string, name="filename")
- self._compression_type = convert.optional_param_to_tensor(
- "compression_type",
- compression_type,
- argument_default="",
- argument_dtype=dtypes.string)
-
- def write(self, dataset):
- """Returns a `tf.Operation` to write a dataset to a file.
-
- Args:
- dataset: a `tf.data.Dataset` whose elements are to be written to a file
-
- Returns:
- A `tf.Operation` that, when run, writes contents of `dataset` to a file.
- """
- if not isinstance(dataset, dataset_ops.Dataset):
- raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- if (dataset.output_types != dtypes.string or
- dataset.output_shapes != tensor_shape.scalar()):
- raise TypeError(
- "`dataset` must produce scalar `DT_STRING` tensors whereas it "
- "produces shape {0} and types {1}".format(dataset.output_shapes,
- dataset.output_types))
- return gen_dataset_ops.dataset_to_tf_record(
- dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
+ super(TFRecordWriter, self).__init__(filename, compression_type)
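
Finally, a graph-mode usage sketch of the writer via its new endpoint; the output path is a placeholder:

```python
import tensorflow as tf

# Write a dataset of scalar strings to a TFRecord file (placeholder path).
dataset = tf.data.Dataset.from_tensor_slices([b"abc", b"def"])
writer = tf.data.experimental.TFRecordWriter("/tmp/example.tfrecord")
write_op = writer.write(dataset)  # a tf.Operation in graph mode

with tf.Session() as sess:
    sess.run(write_op)
```
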