author    | Derek Murray <mrry@google.com> | 2018-10-01 16:45:11 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 16:50:05 -0700
commit    | b72265dc002e712fc3d0f33434f13c7a36a484b2 (patch)
tree      | f92d1f23c329654772f95d93f5cf4458741b72df /tensorflow/python/data
parent    | bb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (diff)
[tf.data] Deprecate `tf.contrib.data` and introduce `tf.data.experimental` to replace it.
This change prepares `tf.data` for TensorFlow 2.0, where `tf.contrib` will no longer exist. It retains the pre-existing endpoints in `tf.contrib.data` with deprecation warnings.
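As a minimal sketch of what the rename means for user code (assuming a TensorFlow 1.x build where `tf.contrib` is still importable; `map_and_batch` is one of the symbols re-exported from the new `tf.data.experimental` package in the diff below):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# Pre-existing endpoint: retained by this change, but now emits a
# deprecation warning when used.
batched_old = dataset.apply(
    tf.contrib.data.map_and_batch(lambda x: x * x, batch_size=10))

# New endpoint introduced by this change; same transformation.
batched_new = dataset.apply(
    tf.data.experimental.map_and_batch(lambda x: x * x, batch_size=10))
```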
Note there are some exceptions to the move:
* Deprecated symbols in `tf.contrib.data` have not been moved to `tf.data.experimental`, because replacements already exist.
* `tf.contrib.data.LMDBDataset` has not been moved, because we plan to move it to a SIG-maintained repository.
* `tf.contrib.data.assert_element_shape()` has not yet been moved, because it depends on functionality in `tf.contrib`; it will be moved in a later change.
* `tf.contrib.data.AUTOTUNE` has not yet been moved, because we have not yet determined how to `tf_export()` a Python integer.
* The stats-related API endpoints have not yet appeared in a released version of TensorFlow, so they have been moved to `tf.data.experimental` without retaining an endpoint in `tf.contrib.data`.
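A rough, hedged sketch of how those stats endpoints are reached under the new name (based on the `stats_ops` symbols imported by the new `__init__.py` in the diff below; the tag string is an arbitrary illustrative label):

```python
import tensorflow as tf

# Route per-element latency statistics from a pipeline into an aggregator,
# using the endpoints that now live directly under tf.data.experimental.
aggregator = tf.data.experimental.StatsAggregator()

dataset = tf.data.Dataset.range(1000)
dataset = dataset.apply(tf.data.experimental.latency_stats("record_latency"))
dataset = dataset.apply(tf.data.experimental.set_stats_aggregator(aggregator))

# The aggregated statistics can be exported as a summary, e.g. for TensorBoard.
stats_summary = aggregator.get_summary()
```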
In addition, this change includes some build rule and ApiDef refactoring:
* Some of the "//third_party/tensorflow/python:training" dependencies had to be split in order to avoid a circular dependency.
* The `tf.contrib.stateless` ops now have a private core library for the generated wrappers (and accordingly are hidden in their ApiDef) so that `tf.data.experimental.sample_from_datasets()` can depend on them.
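For context, a hedged sketch of the `sample_from_datasets()` transformation that motivates that dependency; the weights and seed are illustrative values, not anything prescribed by this change:

```python
import tensorflow as tf

# Two infinite input pipelines to mix together.
zeros = tf.data.Dataset.from_tensors(0).repeat()
ones = tf.data.Dataset.from_tensors(1).repeat()

# Randomly interleaves elements from the inputs according to `weights`;
# the seeded random selection is what pulls in the stateless random ops.
mixed = tf.data.experimental.sample_from_datasets(
    [zeros, ones], weights=[0.7, 0.3], seed=42)
```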
PiperOrigin-RevId: 215304249
Diffstat (limited to 'tensorflow/python/data')
106 files changed, 21452 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD index 138141f4fc..e32eeecbb8 100644 --- a/tensorflow/python/data/BUILD +++ b/tensorflow/python/data/BUILD @@ -10,6 +10,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:util", + "//tensorflow/python/data/experimental", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:multi_device_iterator_ops", diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py index f8b561205e..7536ba668a 100644 --- a/tensorflow/python/data/__init__.py +++ b/tensorflow/python/data/__init__.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import +from tensorflow.python.data import experimental from tensorflow.python.data.ops.dataset_ops import Dataset from tensorflow.python.data.ops.iterator_ops import Iterator from tensorflow.python.data.ops.readers import FixedLengthRecordDataset diff --git a/tensorflow/python/data/experimental/BUILD b/tensorflow/python/data/experimental/BUILD new file mode 100644 index 0000000000..84e761d376 --- /dev/null +++ b/tensorflow/python/data/experimental/BUILD @@ -0,0 +1,16 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "experimental", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:dataset_ops", + "//tensorflow/python/data/experimental/ops:iterator_ops", + ], +) diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py new file mode 100644 index 0000000000..2ac159d38a --- /dev/null +++ b/tensorflow/python/data/experimental/__init__.py @@ -0,0 +1,109 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for building input pipelines. + +This module contains experimental `Dataset` sources and transformations that can +be used in conjunction with the `tf.data.Dataset` API. Note that the +`tf.data.experimental` API is not subject to the same backwards compatibility +guarantees as `tf.data`, but we will provide deprecation advice in advance of +removing existing functionality. + +See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. 
+ +@@Counter +@@CheckpointInputPipelineHook +@@CsvDataset +@@Optional +@@RandomDataset +@@Reducer +@@SqlDataset +@@TFRecordWriter + +@@bucket_by_sequence_length +@@choose_from_datasets +@@copy_to_device +@@dense_to_sparse_batch +@@enumerate_dataset +@@get_next_as_optional +@@get_single_element +@@group_by_reducer +@@group_by_window +@@ignore_errors +@@latency_stats +@@make_batched_features_dataset +@@make_csv_dataset +@@make_saveable_from_iterator +@@map_and_batch +@@parallel_interleave +@@parse_example_dataset +@@prefetch_to_device +@@rejection_resample +@@sample_from_datasets +@@scan +@@set_stats_aggregator +@@shuffle_and_repeat +@@StatsAggregator +@@unbatch +@@unique +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import + +from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch +from tensorflow.python.data.experimental.ops.batching import map_and_batch +from tensorflow.python.data.experimental.ops.batching import unbatch +from tensorflow.python.data.experimental.ops.counter import Counter +from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset +from tensorflow.python.data.experimental.ops.error_ops import ignore_errors +from tensorflow.python.data.experimental.ops.get_single_element import get_single_element +from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length +from tensorflow.python.data.experimental.ops.grouping import group_by_reducer +from tensorflow.python.data.experimental.ops.grouping import group_by_window +from tensorflow.python.data.experimental.ops.grouping import Reducer +from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets +from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave +from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets +from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook +from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator + +# Optimization constant that can be used to enable auto-tuning. 
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE + +from tensorflow.python.data.experimental.ops.parsing_ops import parse_example_dataset +from tensorflow.python.data.experimental.ops.prefetching_ops import copy_to_device +from tensorflow.python.data.experimental.ops.prefetching_ops import prefetch_to_device +from tensorflow.python.data.experimental.ops.random_ops import RandomDataset +from tensorflow.python.data.experimental.ops.readers import CsvDataset +from tensorflow.python.data.experimental.ops.readers import make_batched_features_dataset +from tensorflow.python.data.experimental.ops.readers import make_csv_dataset +from tensorflow.python.data.experimental.ops.readers import SqlDataset +from tensorflow.python.data.experimental.ops.resampling import rejection_resample +from tensorflow.python.data.experimental.ops.scan_ops import scan +from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat +from tensorflow.python.data.experimental.ops.stats_ops import latency_stats +from tensorflow.python.data.experimental.ops.stats_ops import set_stats_aggregator +from tensorflow.python.data.experimental.ops.stats_ops import StatsAggregator +from tensorflow.python.data.experimental.ops.unique import unique +from tensorflow.python.data.experimental.ops.writers import TFRecordWriter +from tensorflow.python.data.ops.iterator_ops import get_next_as_optional +from tensorflow.python.data.ops.optional_ops import Optional +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD new file mode 100644 index 0000000000..a46c30ed2e --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -0,0 +1,569 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") + +py_test( + name = "batch_dataset_op_test", + size = "medium", + srcs = ["batch_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", # (b/79552534) + "no_pip", + ], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "bucketing_test", + size = "medium", + srcs = ["bucketing_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/experimental/ops:grouping", + 
"//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "csv_dataset_op_test", + size = "medium", + srcs = ["csv_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python/data/experimental/ops:error_ops", + "//tensorflow/python/data/experimental/ops:readers", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/eager:context", + "//third_party/py/numpy", + ], +) + +py_test( + name = "dataset_constructor_op_test", + size = "medium", + srcs = ["dataset_constructor_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "manual", + "nomac", # b/62040583 + ], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +py_test( + name = "directed_interleave_dataset_test", + size = "medium", + srcs = ["directed_interleave_dataset_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:random_seed", + "//tensorflow/python/data/experimental/ops:interleave_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "get_single_element_test", + size = "small", + srcs = ["get_single_element_test.py"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/experimental/ops:get_single_element", + "//tensorflow/python/data/experimental/ops:grouping", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "indexed_dataset_ops_test", + srcs = ["indexed_dataset_ops_test.py"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python/data/experimental/ops:indexed_dataset_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "interleave_dataset_op_test", + size = "medium", + srcs = ["interleave_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + "notap", + ], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/experimental/ops:interleave_ops", + 
"//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@six_archive//:six", + ], +) + +py_test( + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/experimental/ops:iterator_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "map_dataset_op_test", + size = "medium", + srcs = ["map_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "noasan", # times out + "optonly", + ], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/experimental/ops:error_ops", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "filter_dataset_op_test", + size = "medium", + srcs = ["filter_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "map_defun_op_test", + size = "small", + srcs = ["map_defun_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:session", + "//tensorflow/python/data/experimental/ops:map_defun", + "//tensorflow/python/data/kernel_tests:test_base", + ], +) + +py_test( + name = "parsing_ops_test", + size = "small", + srcs = ["parsing_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/experimental/ops:parsing_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", + ], +) + +cuda_py_test( + name = "prefetching_ops_test", + size = "small", + srcs = ["prefetching_ops_test.py"], + additional_deps = [ + 
"//tensorflow/python/data/experimental/ops:prefetching_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/compat:compat", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], + tags = ["no_windows_gpu"], +) + +py_test( + name = "range_dataset_op_test", + size = "small", + srcs = ["range_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/experimental/ops:counter", + "//tensorflow/python/data/experimental/ops:enumerate_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( + name = "reader_dataset_ops_test_base", + testonly = 1, + srcs = [ + "reader_dataset_ops_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow/python/data/experimental/kernel_tests:__pkg__", + "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__", + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:readers", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "reader_dataset_ops_test", + size = "medium", + srcs = ["reader_dataset_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:string_ops", + "//tensorflow/python/data/experimental/ops:readers", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", + ], +) + +py_test( + name = "resample_test", + size = "medium", + srcs = ["resample_test.py"], + shard_count = 2, + srcs_version = "PY2AND3", + tags = [ + "noasan", + "optonly", + ], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:string_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:resampling", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@six_archive//:six", + ], +) + +py_test( + name = "scan_dataset_op_test", + size = "small", + srcs = ["scan_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + 
"//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/experimental/ops:scan_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//third_party/py/numpy", + ], +) + +py_test( + name = "shuffle_dataset_op_test", + size = "medium", + srcs = ["shuffle_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "optonly", + ], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/experimental/ops:shuffle_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_library( + name = "sql_dataset_op_test_base", + srcs = ["sql_dataset_op_test_base.py"], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow/python/data/experimental/kernel_tests:__pkg__", + "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__", + ], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python/data/experimental/ops:readers", + "//tensorflow/python/data/kernel_tests:test_base", + "@org_sqlite//:python", + ], +) + +py_test( + name = "sql_dataset_op_test", + size = "small", + srcs = ["sql_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sql_dataset_op_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + ], +) + +py_test( + name = "stats_dataset_ops_test", + size = "medium", + srcs = ["stats_dataset_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":reader_dataset_ops_test_base", + ":stats_dataset_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/experimental/ops:stats_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_library( + name = "stats_dataset_test_base", + srcs = ["stats_dataset_test_base.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:test_base", + ], +) + +py_test( + name = "threadpool_dataset_ops_test", + size = "small", + srcs = ["threadpool_dataset_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:script_ops", + "//tensorflow/python/data/experimental/ops:threadpool", + "//tensorflow/python/data/experimental/ops:unique", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "unique_dataset_op_test", + size = "small", + srcs = ["unique_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:unique", + "//tensorflow/python/data/kernel_tests:test_base", + 
"//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "writer_ops_test", + size = "small", + srcs = ["writer_ops_test.py"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:writers", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:readers", + ], +) diff --git a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py new file mode 100644 index 0000000000..8703b2810e --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py @@ -0,0 +1,672 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import time + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + + def testDenseToSparseBatchDataset(self): + components = np.random.randint(12, size=(100,)).astype(np.int32) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + + for start in range(0, len(components), 4): + results = sess.run(get_next) + self.assertAllEqual([[i, j] + for i, c in enumerate(components[start:start + 4]) + for j in range(c)], results.indices) + self.assertAllEqual( + [c for c in components[start:start + 4] for _ in range(c)], + results.values) + self.assertAllEqual([min(4, + len(components) - start), 12], + results.dense_shape) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def 
testDenseToSparseBatchDatasetWithUnknownShape(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([x, x], x)).apply( + batching.dense_to_sparse_batch( + 4, [5, None])).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + + for start in range(0, len(components), 4): + results = sess.run(get_next) + self.assertAllEqual([[i, j, z] + for i, c in enumerate(components[start:start + 4]) + for j in range(c) + for z in range(c)], results.indices) + self.assertAllEqual([ + c + for c in components[start:start + 4] for _ in range(c) + for _ in range(c) + ], results.values) + self.assertAllEqual([ + min(4, + len(components) - start), 5, + np.max(components[start:start + 4]) + ], results.dense_shape) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testDenseToSparseBatchDatasetWithInvalidShape(self): + input_tensor = array_ops.constant([[1]]) + with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"): + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator() + + def testDenseToSparseBatchDatasetShapeErrors(self): + input_tensor = array_ops.placeholder(dtypes.int32) + iterator = ( + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, [12])) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + # Initialize with an input tensor of incompatible rank. + sess.run(init_op, feed_dict={input_tensor: [[1]]}) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "incompatible with the row shape"): + sess.run(get_next) + + # Initialize with an input tensor that is larger than `row_shape`. 
+ sess.run(init_op, feed_dict={input_tensor: range(13)}) + with self.assertRaisesRegexp(errors.DataLossError, + "larger than the row shape"): + sess.run(get_next) + + def testUnbatchScalarDataset(self): + data = tuple([math_ops.range(10) for _ in range(3)]) + data = dataset_ops.Dataset.from_tensor_slices(data) + expected_types = (dtypes.int32,) * 3 + data = data.batch(2) + self.assertEqual(expected_types, data.output_types) + data = data.apply(batching.unbatch()) + self.assertEqual(expected_types, data.output_types) + + iterator = data.make_one_shot_iterator() + op = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + self.assertEqual((i,) * 3, sess.run(op)) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(op) + + def testUnbatchDatasetWithStrings(self): + data = tuple([math_ops.range(10) for _ in range(3)]) + data = dataset_ops.Dataset.from_tensor_slices(data) + data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z)) + expected_types = (dtypes.int32, dtypes.string, dtypes.int32) + data = data.batch(2) + self.assertEqual(expected_types, data.output_types) + data = data.apply(batching.unbatch()) + self.assertEqual(expected_types, data.output_types) + + iterator = data.make_one_shot_iterator() + op = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(op) + + def testUnbatchDatasetWithSparseTensor(self): + st = sparse_tensor.SparseTensorValue( + indices=[[i, i] for i in range(10)], + values=list(range(10)), + dense_shape=[10, 10]) + data = dataset_ops.Dataset.from_tensors(st) + data = data.apply(batching.unbatch()) + data = data.batch(5) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + st_row = sess.run(next_element) + self.assertEqual([i], st_row.indices) + self.assertEqual([i], st_row.values) + self.assertEqual([10], st_row.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testUnbatchDatasetWithDenseAndSparseTensor(self): + st = sparse_tensor.SparseTensorValue( + indices=[[i, i] for i in range(10)], + values=list(range(10)), + dense_shape=[10, 10]) + data = dataset_ops.Dataset.from_tensors((list(range(10)), st)) + data = data.apply(batching.unbatch()) + data = data.batch(5) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + dense_elem, st_row = sess.run(next_element) + self.assertEqual(i, dense_elem) + self.assertEqual([i], st_row.indices) + self.assertEqual([i], st_row.values) + self.assertEqual([10], st_row.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testUnbatchSingleElementTupleDataset(self): + data = tuple([(math_ops.range(10),) for _ in range(3)]) + data = dataset_ops.Dataset.from_tensor_slices(data) + expected_types = ((dtypes.int32,),) * 3 + data = data.batch(2) + self.assertEqual(expected_types, data.output_types) + data = data.apply(batching.unbatch()) + self.assertEqual(expected_types, data.output_types) + + iterator = data.make_one_shot_iterator() + op = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + self.assertEqual(((i,),) * 3, sess.run(op)) + + with 
self.assertRaises(errors.OutOfRangeError): + sess.run(op) + + def testUnbatchMultiElementTupleDataset(self): + data = tuple([(math_ops.range(10 * i, 10 * i + 10), + array_ops.fill([10], "hi")) for i in range(3)]) + data = dataset_ops.Dataset.from_tensor_slices(data) + expected_types = ((dtypes.int32, dtypes.string),) * 3 + data = data.batch(2) + self.assertAllEqual(expected_types, data.output_types) + data = data.apply(batching.unbatch()) + self.assertAllEqual(expected_types, data.output_types) + + iterator = data.make_one_shot_iterator() + op = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), + sess.run(op)) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(op) + + def testUnbatchEmpty(self): + data = dataset_ops.Dataset.from_tensors( + (constant_op.constant([]), constant_op.constant([], shape=[0, 4]), + constant_op.constant([], shape=[0, 4, 0]))) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testUnbatchStaticShapeMismatch(self): + data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8), + np.arange(9))) + with self.assertRaises(ValueError): + data.apply(batching.unbatch()) + + def testUnbatchDynamicShapeMismatch(self): + ph1 = array_ops.placeholder(dtypes.int32, shape=[None]) + ph2 = array_ops.placeholder(dtypes.int32, shape=None) + data = dataset_ops.Dataset.from_tensors((ph1, ph2)) + data = data.apply(batching.unbatch()) + iterator = data.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + # Mismatch in the 0th dimension. + sess.run( + iterator.initializer, + feed_dict={ + ph1: np.arange(7).astype(np.int32), + ph2: np.arange(8).astype(np.int32) + }) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(next_element) + + # No 0th dimension (i.e. scalar value) for one component. + sess.run( + iterator.initializer, + feed_dict={ + ph1: np.arange(7).astype(np.int32), + ph2: 7 + }) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(next_element) + + @parameterized.named_parameters( + ("Default", None, None), + ("SequentialCalls", 1, None), + ("ParallelCalls", 2, None), + ("ParallelBatches", None, 10), + ) + def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): + """Test a dataset that maps a TF function across its input elements.""" + # The pipeline is TensorSliceDataset -> + # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). 
+ components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + + count = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + num_parallel_batches=num_parallel_batches)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.cached_session() as sess: + # Batch of a finite input, where the batch_size divides the + # total number of elements. + sess.run(init_op, feed_dict={count: 28, batch_size: 14}) + num_batches = (28 * 7) // 14 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i * 14 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Batch of a finite input, where the batch_size does not + # divide the total number of elements. + sess.run(init_op, feed_dict={count: 14, batch_size: 8}) + + # We expect (num_batches - 1) full-sized batches. + num_batches = int(math.ceil((14 * 7) / 8)) + for i in range(num_batches - 1): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(8): + self.assertAllEqual(component[(i * 8 + j) % 7]**2, + result_component[j]) + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range((14 * 7) % 8): + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Batch of an empty input should fail straight away. + sess.run(init_op, feed_dict={count: 0, batch_size: 8}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Empty batch should be an initialization time error. 
+ with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + + @parameterized.named_parameters( + ("Even", False), + ("Uneven", True), + ) + def testMapAndBatchPartialBatch(self, drop_remainder): + iterator = ( + dataset_ops.Dataset.range(10).apply( + batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), + batch_size=4, + drop_remainder=drop_remainder)).make_one_shot_iterator()) + if drop_remainder: + self.assertEqual([4, 1], iterator.output_shapes.as_list()) + else: + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.cached_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + if not drop_remainder: + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMapAndBatchYieldsPartialBatch(self): + iterator = (dataset_ops.Dataset.range(10) + .apply(batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), 4)) + .make_one_shot_iterator()) + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.cached_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMapAndBatchParallelGetNext(self): + iterator = (dataset_ops.Dataset.range(50000) + .apply(batching.map_and_batch(lambda x: x, batch_size=100)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.cached_session() as sess: + for i in range(5): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + + def testMapAndBatchParallelGetNextDropRemainder(self): + iterator = ( + dataset_ops.Dataset.range(49999).apply( + batching.map_and_batch( + lambda x: x, batch_size=100, drop_remainder=True)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.cached_session() as sess: + for i in range(4): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + + def testMapAndBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).apply( + batching.map_and_batch(_sparse, 5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with 
self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testMapAndBatchFails(self): + """Test a dataset that maps a TF function across its input elements.""" + dataset = dataset_ops.Dataset.from_tensors( + array_ops.check_numerics( + constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = ( + dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + with self.cached_session() as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): + sess.run(init_op, feed_dict={batch_size: 14}) + + def testMapAndBatchShapeMismatch(self): + """Test a dataset that maps a TF function across its input elements.""" + + def generator(): + yield [1] + yield [2] + yield [3] + yield [[4, 5, 6]] + + dataset = dataset_ops.Dataset.from_generator( + generator, output_types=dtypes.int32) + batch_size = 4 + iterator = ( + dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "number of elements does not match"): + sess.run(get_next) + + def testMapAndBatchImplicitDispose(self): + # Tests whether a map and batch dataset will be cleaned up correctly when + # the pipeline does not run it until exhaustion. + # The pipeline is TensorSliceDataset -> RepeatDataset(1000) -> + # MapAndBatchDataset(f=square_3, batch_size=100). + components = (np.arange(1000), + np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], + np.array(37.0) * np.arange(1000)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat( + 1000).apply(batching.map_and_batch(_map_fn, batch_size=100)) + dataset = dataset.prefetch(5) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + for _ in range(3): + sess.run(get_next) + + @parameterized.named_parameters( + ("1", 0), + ("2", 5), + ("3", 10), + ("4", 90), + ("5", 95), + ("6", 99), + ) + def testMapAndBatchOutOfRangeError(self, threshold): + + def raising_py_fn(i): + if i >= threshold: + raise StopIteration() + else: + return i + + iterator = ( + dataset_ops.Dataset.range(100).apply( + batching.map_and_batch( + lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64), + batch_size=10)).make_one_shot_iterator()) + get_next = iterator.get_next() + + with self.cached_session() as sess: + for i in range(threshold // 10): + self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) + if threshold % 10 != 0: + self.assertAllEqual( + [threshold // 10 * 10 + j for j in range(threshold % 10)], + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @parameterized.named_parameters( + ("1", False, dtypes.bool), + ("2", -42, dtypes.int8), + ("3", -42, dtypes.int16), + ("4", -42, dtypes.int32), + ("5", -42, dtypes.int64), + ("6", 42, dtypes.uint8), + ("7", 42, dtypes.uint16), + ("8", 42.0, dtypes.float16), + ("9", 42.0, dtypes.float32), + ("10", 42.0, dtypes.float64), + ("11", b"hello", dtypes.string), + ) + def testMapAndBatchTypes(self, element, dtype): + def gen(): + yield element + + dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply( + 
batching.map_and_batch(lambda x: x, batch_size=10)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + for _ in range(10): + self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) + + +class UnbatchDatasetBenchmark(test.Benchmark): + + def benchmarkNativeUnbatch(self): + batch_sizes = [1, 2, 5, 10, 20, 50] + elems_per_trial = 10000 + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors("element").repeat(None) + batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset.batch(batch_size_placeholder) + dataset = dataset.apply(batching.unbatch()) + dataset = dataset.skip(elems_per_trial) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for batch_size in batch_sizes: + deltas = [] + for _ in range(5): + sess.run( + iterator.initializer, + feed_dict={batch_size_placeholder: batch_size}) + start = time.time() + sess.run(next_element.op) + end = time.time() + deltas.append((end - start) / elems_per_trial) + + median_wall_time = np.median(deltas) + print("Unbatch (native) batch size: %d Median wall time per element:" + " %f microseconds" % (batch_size, median_wall_time * 1e6)) + self.report_benchmark( + iters=10000, + wall_time=median_wall_time, + name="benchmark_unbatch_dataset_native_batch_size_%d" % + batch_size) + + # Include a benchmark of the previous `unbatch()` implementation that uses + # a composition of more primitive ops. Eventually we'd hope to generate code + # that is as good in both cases. + def benchmarkOldUnbatchImplementation(self): + batch_sizes = [1, 2, 5, 10, 20, 50] + elems_per_trial = 10000 + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors("element").repeat(None) + batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset.batch(batch_size_placeholder) + dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices) + dataset = dataset.skip(elems_per_trial) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for batch_size in batch_sizes: + deltas = [] + for _ in range(5): + sess.run( + iterator.initializer, + feed_dict={batch_size_placeholder: batch_size}) + start = time.time() + sess.run(next_element.op) + end = time.time() + deltas.append((end - start) / elems_per_trial) + + median_wall_time = np.median(deltas) + print("Unbatch (unfused) batch size: %d Median wall time per element:" + " %f microseconds" % (batch_size, median_wall_time * 1e6)) + self.report_benchmark( + iters=10000, + wall_time=median_wall_time, + name="benchmark_unbatch_dataset_unfused_batch_size_%d" % + batch_size) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py new file mode 100644 index 0000000000..153a03989b --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py @@ -0,0 +1,824 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import numpy as np + +from tensorflow.python.data.experimental.ops import grouping +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class GroupByReducerTest(test_base.DatasetTestBase): + + def checkResults(self, dataset, shapes, values): + self.assertEqual(shapes, dataset.output_shapes) + get_next = dataset.make_one_shot_iterator().get_next() + with self.cached_session() as sess: + for expected in values: + got = sess.run(get_next) + self.assertEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSum(self): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer(lambda x: x % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testAverage(self): + + def reduce_fn(x, y): + return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / ( + x[1] + 1), x[1] + 1 + + reducer = grouping.Reducer( + init_func=lambda _: (0.0, 0.0), + reduce_func=reduce_fn, + finalize_func=lambda x, _: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer( + lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[i - 1, i]) + + def testConcat(self): + components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) + reducer = grouping.Reducer( + init_func=lambda x: "", + reduce_func=lambda x, y: x + y[0], + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensor_slices(components), + dataset_ops.Dataset.range(2 * i))).apply( + grouping.group_by_reducer(lambda x, y: y % 2, reducer)) + self.checkResults( + dataset, + shapes=tensor_shape.scalar(), + values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]]) + + def testSparseSum(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1], dtype=np.int64)), + dense_shape=np.array([1, 1])) + + reducer = grouping.Reducer( + init_func=lambda _: _sparse(np.int64(0)), + reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]), + 
finalize_func=lambda x: x.values[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply( + grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testChangingStateShape(self): + + def reduce_fn(x, _): + # Statically known rank, but dynamic length. + larger_dim = array_ops.concat([x[0], x[0]], 0) + # Statically unknown rank. + larger_rank = array_ops.expand_dims(x[1], 0) + return larger_dim, larger_rank + + reducer = grouping.Reducer( + init_func=lambda x: ([0], 1), + reduce_func=reduce_fn, + finalize_func=lambda x, y: (x, y)) + + for i in range(1, 11): + dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( + grouping.group_by_reducer(lambda x: x, reducer)) + self.assertEqual([None], dataset.output_shapes[0].as_list()) + self.assertIs(None, dataset.output_shapes[1].ndims) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.cached_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual([0] * (2**i), x) + self.assertAllEqual(np.array(1, ndmin=i), y) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testTypeMismatch(self): + reducer = grouping.Reducer( + init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), + reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64), + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64(0), reducer)) + + # TODO(b/78665031): Remove once non-scalar keys are supported. + def testInvalidKeyShape(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) + + # TODO(b/78665031): Remove once non-int64 keys are supported. 
+ def testInvalidKeyType(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: "wrong", reducer)) + + def testTuple(self): + def init_fn(_): + return np.array([], dtype=np.int64), np.int64(0) + + def reduce_fn(state, value): + s1, s2 = state + v1, v2 = value + return array_ops.concat([s1, [v1]], 0), s2 + v2 + + def finalize_fn(s1, s2): + return s1, s2 + + reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) + dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( + grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.cached_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual(x, np.asarray([x for x in range(10)])) + self.assertEqual(y, 45) + + +class GroupByWindowTest(test_base.DatasetTestBase): + + def testSimple(self): + components = np.random.randint(100, size=(200,)).astype(np.int64) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x) + .apply( + grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), + 4)).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + counts = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + result = sess.run(get_next) + self.assertTrue( + all(x % 2 == 0 + for x in result) or all(x % 2 == 1) + for x in result) + counts.append(result.shape[0]) + + self.assertEqual(len(components), sum(counts)) + num_full_batches = len([c for c in counts if c == 4]) + self.assertGreaterEqual(num_full_batches, 24) + self.assertTrue(all(c == 4 for c in counts[:num_full_batches])) + + def testImmediateOutput(self): + components = np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), + 4)).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + # The input is infinite, so this test demonstrates that: + # 1. We produce output without having to consume the entire input, + # 2. Different buckets can produce output at different rates, and + # 3. For deterministic input, the output is deterministic. 
+ for _ in range(3): + self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) + self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) + self.assertAllEqual([2, 2, 2, 2], sess.run(get_next)) + self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) + + def testSmallGroups(self): + components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), + 4)).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) + self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) + # The small outputs at the end are deterministically produced in key + # order. + self.assertAllEqual([0, 0, 0], sess.run(get_next)) + self.assertAllEqual([1], sess.run(get_next)) + + def testEmpty(self): + iterator = ( + dataset_ops.Dataset.range(4).apply( + grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Window size must be greater than zero, but got 0."): + print(sess.run(get_next)) + + def testReduceFuncError(self): + components = np.random.randint(100, size=(200,)).astype(np.int64) + + def reduce_func(_, xs): + # Introduce an incorrect padded shape that cannot (currently) be + # detected at graph construction time. + return xs.padded_batch( + 4, + padded_shapes=(tensor_shape.TensorShape([]), + constant_op.constant([5], dtype=dtypes.int64) * -1)) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply( + grouping.group_by_window(lambda x, _: x % 2, reduce_func, + 32)).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + def testConsumeWindowDatasetMoreThanOnce(self): + components = np.random.randint(50, size=(200,)).astype(np.int64) + + def reduce_func(key, window): + # Apply two different kinds of padding to the input: tight + # padding, and quantized (to a multiple of 10) padding. + return dataset_ops.Dataset.zip(( + window.padded_batch( + 4, padded_shapes=tensor_shape.TensorShape([None])), + window.padded_batch( + 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])), + )) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x)) + .apply(grouping.group_by_window( + lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64), + reduce_func, 4)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + counts = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + tight_result, multiple_of_10_result = sess.run(get_next) + self.assertEqual(0, multiple_of_10_result.shape[1] % 10) + self.assertAllEqual(tight_result, + multiple_of_10_result[:, :tight_result.shape[1]]) + counts.append(tight_result.shape[0]) + self.assertEqual(len(components), sum(counts)) + + +# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. 
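For orientation, a minimal sketch of the `group_by_window()` behaviour that `GroupByWindowTest` checks, again using the public endpoint and assuming a TF 1.x graph-mode session:

import tensorflow as tf

# Batch even and odd numbers separately, four at a time.
dataset = tf.data.Dataset.range(12).apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,                        # route by parity
        reduce_func=lambda key, window: window.batch(4),
        window_size=4))                                  # elements per window

get_next = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(get_next))  # e.g. [0 2 4 6]
  print(sess.run(get_next))  # e.g. [1 3 5 7]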
+# Currently, they use a constant batch size, though should be made to use a +# different batch size per key. +class BucketTest(test_base.DatasetTestBase): + + def _dynamicPad(self, bucket, window, window_size): + # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a + # generic form of padded_batch that pads every component + # dynamically and does not rely on static shape information about + # the arguments. + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(bucket), + window.padded_batch( + 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape( + [None]), tensor_shape.TensorShape([3]))))) + + def testSingleBucket(self): + + def _map_fn(v): + return (v, array_ops.fill([v], v), + array_ops.fill([3], string_ops.as_string(v))) + + input_dataset = ( + dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn)) + + bucketed_dataset = input_dataset.apply( + grouping.group_by_window( + lambda x, y, z: 0, + lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) + + iterator = bucketed_dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + + which_bucket, bucketed_values = sess.run(get_next) + + self.assertEqual(0, which_bucket) + + expected_scalar_int = np.arange(32, dtype=np.int64) + expected_unk_int64 = np.zeros((32, 31)).astype(np.int64) + for i in range(32): + expected_unk_int64[i, :i] = i + expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T + + self.assertAllEqual(expected_scalar_int, bucketed_values[0]) + self.assertAllEqual(expected_unk_int64, bucketed_values[1]) + self.assertAllEqual(expected_vec3_str, bucketed_values[2]) + + def testEvenOddBuckets(self): + + def _map_fn(v): + return (v, array_ops.fill([v], v), + array_ops.fill([3], string_ops.as_string(v))) + + input_dataset = ( + dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn)) + + bucketed_dataset = input_dataset.apply( + grouping.group_by_window( + lambda x, y, z: math_ops.cast(x % 2, dtypes.int64), + lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) + + iterator = bucketed_dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + + # Get two minibatches (one containing even values, one containing odds) + which_bucket_even, bucketed_values_even = sess.run(get_next) + which_bucket_odd, bucketed_values_odd = sess.run(get_next) + + # Count number of bucket_tensors. + self.assertEqual(3, len(bucketed_values_even)) + self.assertEqual(3, len(bucketed_values_odd)) + + # Ensure bucket 0 was used for all minibatch entries. 
+ self.assertAllEqual(0, which_bucket_even) + self.assertAllEqual(1, which_bucket_odd) + + # Test the first bucket outputted, the events starting at 0 + expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64) + expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64) + for i in range(0, 32): + expected_unk_int64[i, :2 * i] = 2 * i + expected_vec3_str = np.vstack( + 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T + + self.assertAllEqual(expected_scalar_int, bucketed_values_even[0]) + self.assertAllEqual(expected_unk_int64, bucketed_values_even[1]) + self.assertAllEqual(expected_vec3_str, bucketed_values_even[2]) + + # Test the second bucket outputted, the odds starting at 1 + expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64) + expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64) + for i in range(0, 32): + expected_unk_int64[i, :2 * i + 1] = 2 * i + 1 + expected_vec3_str = np.vstack( + 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T + + self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0]) + self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1]) + self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2]) + + def testEvenOddBucketsFilterOutAllOdd(self): + + def _map_fn(v): + return { + "x": v, + "y": array_ops.fill([v], v), + "z": array_ops.fill([3], string_ops.as_string(v)) + } + + def _dynamic_pad_fn(bucket, window, _): + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(bucket), + window.padded_batch( + 32, { + "x": tensor_shape.TensorShape([]), + "y": tensor_shape.TensorShape([None]), + "z": tensor_shape.TensorShape([3]) + }))) + + input_dataset = ( + dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn) + .filter(lambda d: math_ops.equal(d["x"] % 2, 0))) + + bucketed_dataset = input_dataset.apply( + grouping.group_by_window( + lambda d: math_ops.cast(d["x"] % 2, dtypes.int64), + lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)) + + iterator = bucketed_dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + + # Get two minibatches ([0, 2, ...] 
and [64, 66, ...]) + which_bucket0, bucketed_values_even0 = sess.run(get_next) + which_bucket1, bucketed_values_even1 = sess.run(get_next) + + # Ensure that bucket 1 was completely filtered out + self.assertAllEqual(0, which_bucket0) + self.assertAllEqual(0, which_bucket1) + self.assertAllEqual( + np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"]) + self.assertAllEqual( + np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"]) + + def testDynamicWindowSize(self): + components = np.arange(100).astype(np.int64) + + # Key fn: even/odd + # Reduce fn: batches of 5 + # Window size fn: even=5, odd=10 + + def window_size_func(key): + window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64) + return window_sizes[key] + + dataset = dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20), + None, window_size_func)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.OutOfRangeError): + batches = 0 + while True: + result = sess.run(get_next) + is_even = all(x % 2 == 0 for x in result) + is_odd = all(x % 2 == 1 for x in result) + self.assertTrue(is_even or is_odd) + expected_batch_size = 5 if is_even else 10 + self.assertEqual(expected_batch_size, result.shape[0]) + batches += 1 + + self.assertEqual(batches, 15) + + +def _element_length_fn(x, y=None): + del y + return array_ops.shape(x)[0] + + +def _to_sparse_tensor(record): + return sparse_tensor.SparseTensor(**record) + + +def _format_record(array, sparse): + if sparse: + return { + "values": array, + "indices": [[i] for i in range(len(array))], + "dense_shape": (len(array),) + } + return array + + +def _get_record_type(sparse): + if sparse: + return { + "values": dtypes.int64, + "indices": dtypes.int64, + "dense_shape": dtypes.int64 + } + return dtypes.int32 + + +def _get_record_shape(sparse): + if sparse: + return { + "values": tensor_shape.TensorShape([None,]), + "indices": tensor_shape.TensorShape([None, 1]), + "dense_shape": tensor_shape.TensorShape([1,]) + } + return tensor_shape.TensorShape([None]) + + +class BucketBySequenceLength(test_base.DatasetTestBase): + + def testBucket(self): + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + lengths = [8, 13, 25, 35] + + def build_dataset(sparse): + def _generator(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes, lengths): + record_len = length - 1 + for _ in range(batch_size): + elements.append([1] * record_len) + record_len = length + random.shuffle(elements) + for el in elements: + yield (_format_record(el, sparse),) + dataset = dataset_ops.Dataset.from_generator( + _generator, + (_get_record_type(sparse),), + (_get_record_shape(sparse),)) + if sparse: + dataset = dataset.map(lambda x: (_to_sparse_tensor(x),)) + return dataset + + def _test_bucket_by_padding(no_padding): + dataset = build_dataset(sparse=no_padding) + dataset = dataset.apply( + grouping.bucket_by_sequence_length( + _element_length_fn, + boundaries, + batch_sizes, + no_padding=no_padding)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + batches = [] + for _ in range(4): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + shape = batch.dense_shape if 
no_padding else batch.shape + batch_size = shape[0] + length = shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + sum_check = batch.values.sum() if no_padding else batch.sum() + self.assertEqual(sum_check, batch_size * length - 1) + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(lengths), sorted(lengths_val)) + + for no_padding in (True, False): + _test_bucket_by_padding(no_padding) + + def testPadToBoundary(self): + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + lengths = [8, 13, 25] + + def element_gen(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes[:-1], lengths): + for _ in range(batch_size): + elements.append([1] * length) + random.shuffle(elements) + for el in elements: + yield (el,) + for _ in range(batch_sizes[-1]): + el = [1] * (boundaries[-1] + 5) + yield (el,) + + element_len = lambda el: array_ops.shape(el)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + batches = [] + for _ in range(3): + batches.append(sess.run(batch)) + with self.assertRaisesOpError("bucket_boundaries"): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + batch_size = batch.shape[0] + length = batch.shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + batch_sizes = batch_sizes[:-1] + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], + sorted(lengths_val)) + + def testPadToBoundaryNoExtraneousPadding(self): + + boundaries = [3, 7, 11] + batch_sizes = [2, 2, 2, 2] + lengths = range(1, 11) + + def element_gen(): + for length in lengths: + yield ([1] * length,) + + element_len = lambda element: array_ops.shape(element)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + batches = [] + for _ in range(5): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + + self.assertAllEqual(batches[0], [[1, 0], + [1, 1]]) + self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1]]) + self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + + def testTupleElements(self): + + def build_dataset(sparse): + def _generator(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (_format_record(x, sparse), y) + dataset = dataset_ops.Dataset.from_generator( + generator=_generator, + output_types=(_get_record_type(sparse), dtypes.int32), + output_shapes=(_get_record_shape(sparse), + tensor_shape.TensorShape([]))) + if sparse: + dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), 
y)) + return dataset + + def _test_tuple_elements_by_padding(no_padding): + dataset = build_dataset(sparse=no_padding) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=_element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8], + no_padding=no_padding)) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + + for no_padding in (True, False): + _test_tuple_elements_by_padding(no_padding) + + def testBucketSparse(self): + """Tests bucketing of sparse tensors (case where `no_padding` == True). + + Test runs on following dataset: + [ + [0], + [0, 1], + [0, 1, 2] + ... + [0, ..., max_len - 1] + ] + Sequences are bucketed by length and batched with + `batch_size` < `bucket_size`. + """ + + min_len = 0 + max_len = 100 + batch_size = 7 + bucket_size = 10 + + def _build_dataset(): + input_data = [range(i+1) for i in range(min_len, max_len)] + def generator_fn(): + for record in input_data: + yield _format_record(record, sparse=True) + dataset = dataset_ops.Dataset.from_generator( + generator=generator_fn, + output_types=_get_record_type(sparse=True)) + dataset = dataset.map(_to_sparse_tensor) + return dataset + + def _compute_expected_batches(): + """Computes expected batch outputs and stores in a set.""" + all_expected_sparse_tensors = set() + for bucket_start_len in range(min_len, max_len, bucket_size): + for batch_offset in range(0, bucket_size, batch_size): + batch_start_len = bucket_start_len + batch_offset + batch_end_len = min(batch_start_len + batch_size, + bucket_start_len + bucket_size) + expected_indices = [] + expected_values = [] + for length in range(batch_start_len, batch_end_len): + for val in range(length + 1): + expected_indices.append((length - batch_start_len, val)) + expected_values.append(val) + expected_sprs_tensor = (tuple(expected_indices), + tuple(expected_values)) + all_expected_sparse_tensors.add(expected_sprs_tensor) + return all_expected_sparse_tensors + + def _compute_batches(dataset): + """Computes actual batch outputs of dataset and stores in a set.""" + batch = dataset.make_one_shot_iterator().get_next() + all_sparse_tensors = set() + with self.cached_session() as sess: + with self.assertRaises(errors.OutOfRangeError): + while True: + output = sess.run(batch) + sprs_tensor = (tuple([tuple(idx) for idx in output.indices]), + tuple(output.values)) + all_sparse_tensors.add(sprs_tensor) + return all_sparse_tensors + + dataset = _build_dataset() + boundaries = range(min_len + bucket_size + 1, max_len, bucket_size) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + _element_length_fn, + boundaries, + [batch_size] * (len(boundaries) + 1), + no_padding=True)) + batches = _compute_batches(dataset) + expected_batches = _compute_expected_batches() + self.assertEqual(batches, expected_batches) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py new file mode 100644 index 0000000000..4ee1779710 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py @@ -0,0 +1,632 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
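The `BucketBySequenceLength` cases above are built around `bucket_by_sequence_length()`, which assigns each element to a length bucket and batches (with padding) per bucket. A minimal sketch with the public endpoint, assuming a TF 1.x graph-mode session:

import tensorflow as tf

def gen():
  # Variable-length sequences of ones.
  for n in (1, 2, 5, 6, 9, 10):
    yield [1] * n

dataset = tf.data.Dataset.from_generator(
    gen, output_types=tf.int64, output_shapes=tf.TensorShape([None]))
dataset = dataset.apply(
    tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda x: tf.shape(x)[0],
        bucket_boundaries=[4, 8],       # three buckets: <4, 4..7, >=8
        bucket_batch_sizes=[2, 2, 2]))  # one batch size per bucket

get_next = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(get_next))  # a padded two-element batch from one bucket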
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CsvDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import string +import tempfile +import time +import zlib + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.data.experimental.ops import error_ops +from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +@test_util.run_all_in_graph_and_eager_modes +class CsvDatasetOpTest(test_base.DatasetTestBase): + + def _setup_files(self, inputs, linebreak='\n', compression_type=None): + filenames = [] + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) + contents = linebreak.join(ip).encode('utf-8') + if compression_type is None: + with open(fn, 'wb') as f: + f.write(contents) + elif compression_type == 'GZIP': + with gzip.GzipFile(fn, 'wb') as f: + f.write(contents) + elif compression_type == 'ZLIB': + contents = zlib.compress(contents) + with open(fn, 'wb') as f: + f.write(contents) + else: + raise ValueError('Unsupported compression_type', compression_type) + filenames.append(fn) + return filenames + + def _make_test_datasets(self, inputs, **kwargs): + # Test by comparing its output to what we could get with map->decode_csv + filenames = self._setup_files(inputs) + dataset_expected = core_readers.TextLineDataset(filenames) + dataset_expected = dataset_expected.map( + lambda l: parsing_ops.decode_csv(l, **kwargs)) + dataset_actual = readers.CsvDataset(filenames, **kwargs) + return (dataset_actual, dataset_expected) + + def _test_by_comparison(self, inputs, **kwargs): + """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" + dataset_actual, dataset_expected = self._make_test_datasets( + inputs, **kwargs) + self.assertDatasetsEqual(dataset_actual, dataset_expected) + + def _verify_output_or_err(self, + dataset, + expected_output=None, + expected_err_re=None): + if expected_err_re is None: + # Verify that output is expected, without errors + nxt = self.getNext(dataset) + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = self.evaluate(nxt()) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(nxt()) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + nxt = self.getNext(dataset) + while 
True: + try: + self.evaluate(nxt()) + except errors.OutOfRangeError: + break + + def _test_dataset( + self, + inputs, + expected_output=None, + expected_err_re=None, + linebreak='\n', + compression_type=None, # Used for both setup and parsing + **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self._setup_files(inputs, linebreak, compression_type) + kwargs['compression_type'] = compression_type + dataset = readers.CsvDataset(filenames, **kwargs) + self._verify_output_or_err(dataset, expected_output, expected_err_re) + + def testCsvDataset_requiredFields(self): + record_defaults = [[]] * 4 + inputs = [['1,2,3,4']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_int(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_float(self): + record_defaults = [[0.0]] * 4 + inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_string(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withEmptyFields(self): + record_defaults = [[0]] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_errWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_dataset( + inputs, + expected_err_re='Unquoted fields cannot have quotes inside', + record_defaults=record_defaults) + + def testCsvDataset_errWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['"a"b","c","d"']] + self._test_dataset( + inputs, + expected_err_re= + 'Quote inside a string has to be escaped by another quote', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] + filenames = self._setup_files(inputs) + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(dataset, [['e', 'f', 'g']]) + + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] + filenames = self._setup_files(inputs) + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(dataset, [['e', 'f', 'g']]) + + def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_mixedTypes(self): + record_defaults = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.string), + constant_op.constant([], dtype=dtypes.float64) + ] + inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withUseQuoteDelimFalse(self): + record_defaults = [['']] * 4 + inputs = [['1,2,"3,4"', '"5,6",7,8']] + 
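The `_test_by_comparison` cases check that `CsvDataset` with `record_defaults` yields the same elements as a `TextLineDataset` followed by `decode_csv`. A minimal sketch of the two pipelines, assuming a hypothetical `data.csv` with an int, a float and a string column:

import tensorflow as tf

filename = 'data.csv'  # hypothetical input file
record_defaults = [
    tf.constant([], dtype=tf.int32),       # required int column (no default)
    tf.constant([0.0], dtype=tf.float32),  # float column, default 0.0
    tf.constant([''], dtype=tf.string),    # string column, default ''
]

# Fused reader exposed under tf.data.experimental by this change.
fused = tf.data.experimental.CsvDataset(filename, record_defaults)

# Reference pipeline that the tests compare against.
reference = tf.data.TextLineDataset(filename).map(
    lambda line: tf.decode_csv(line, record_defaults))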
self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_withFieldDelim(self): + record_defaults = [[0]] * 4 + inputs = [['1:2:3:4', '5:6:7:8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, field_delim=':') + + def testCsvDataset_withNaValue(self): + record_defaults = [[0]] * 4 + inputs = [['1,NA,3,4', 'NA,6,7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, na_value='NA') + + def testCsvDataset_withSelectCols(self): + record_defaults = [['']] * 2 + inputs = [['1,2,3,4', '"5","6","7","8"']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, select_cols=[1, 2]) + + def testCsvDataset_withSelectColsTooHigh(self): + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults, + select_cols=[3, 4]) + + def testCsvDataset_withOneCol(self): + record_defaults = [['NA']] + inputs = [['0', '', '2']] + self._test_dataset( + inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withMultipleFiles(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withLeadingAndTrailingSpaces(self): + record_defaults = [[0.0]] * 4 + inputs = [['0, 1, 2, 3']] + expected = [[0.0, 1.0, 2.0, 3.0]] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithMissingDefault(self): + record_defaults = [[]] * 2 + inputs = [['0,']] + self._test_dataset( + inputs, + expected_err_re='Field 1 is required but missing in record!', + record_defaults=record_defaults) + + def testCsvDataset_errorWithFewerDefaultsThanFields(self): + record_defaults = [[0.0]] * 2 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have more in record', + record_defaults=record_defaults) + + def testCsvDataset_errorWithMoreDefaultsThanFields(self): + record_defaults = [[0.0]] * 5 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 5 fields but have 4 in record', + record_defaults=record_defaults) + + def testCsvDataset_withHeader(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2', '1,2']] + expected = [[1, 2]] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withHeaderAndNoRecords(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2']] + expected = [] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_errorWithHeaderEmptyFile(self): + record_defaults = [[0]] * 2 + inputs = [[]] + expected_err_re = "Can't read header of file" + self._test_dataset( + inputs, + expected_err_re=expected_err_re, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withEmptyFile(self): + record_defaults = [['']] * 2 + inputs = [['']] # Empty file + self._test_dataset( + inputs, expected_output=[], record_defaults=record_defaults) + + def testCsvDataset_errorWithEmptyRecord(self): + record_defaults = [['']] * 2 + inputs = [['', '1,2']] # First record is empty + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults) + + def testCsvDataset_withChainedOps(self): + # Testing that one dataset can create 
multiple iterators fine. + # `repeat` creates multiple iterators from the same C++ Dataset. + record_defaults = [[0]] * 4 + inputs = [['1,,3,4', '5,6,,8']] + ds_actual, ds_expected = self._make_test_datasets( + inputs, record_defaults=record_defaults) + self.assertDatasetsEqual( + ds_actual.repeat(5).prefetch(1), + ds_expected.repeat(5).prefetch(1)) + + def testCsvDataset_withTypeDefaults(self): + # Testing using dtypes as record_defaults for required fields + record_defaults = [dtypes.float32, [0.0]] + inputs = [['1.0,2.0', '3.0,4.0']] + self._test_dataset( + inputs, + [[1.0, 2.0], [3.0, 4.0]], + record_defaults=record_defaults, + ) + + def testMakeCsvDataset_fieldOrder(self): + data = [[ + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' + ]] + file_path = self._setup_files(data) + + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + nxt = self.getNext(ds) + + result = list(self.evaluate(nxt()).values()) + + self.assertEqual(result, sorted(result)) + +## The following tests exercise parsing logic for quoted fields + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withOneColAndQuotes(self): + record_defaults = [['']] + inputs = [['"0"', '"1"', '"2"']] + self._test_dataset( + inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withNewLineInUnselectedCol(self): + record_defaults = [['']] + inputs = [['1,"2\n3",4', '5,6,7']] + self._test_dataset( + inputs, + expected_output=[['1'], ['5']], + record_defaults=record_defaults, + select_cols=[0]) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithTerminateMidRecord(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,"a']] + self._test_dataset( + inputs, + expected_err_re= + 'Reached end of file without closing quoted field in record', + record_defaults=record_defaults) + + def testCsvDataset_withEscapedQuotes(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + +## Testing that parsing works with all buffer sizes, quoted/unquoted fields, +## and different types of line breaks + + def testCsvDataset_withInvalidBufferSize(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,d']] + self._test_dataset( + inputs, + expected_err_re='buffer_size should be positive', + record_defaults=record_defaults, + buffer_size=0) + + def _test_dataset_on_buffer_sizes(self, + inputs, + expected, + linebreak, + record_defaults, + compression_type=None, + num_sizes_to_test=20): + # 
Testing reading with a range of buffer sizes that should all work. + for i in list(range(1, 1 + num_sizes_to_test)) + [None]: + self._test_dataset( + inputs, + expected, + linebreak=linebreak, + compression_type=compression_type, + record_defaults=record_defaults, + buffer_size=i) + + def testCsvDataset_withLF(self): + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCR(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLF(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + + def testCsvDataset_withBufferSizeAndQuoted(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCRAndQuoted(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLFAndQuoted(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + + def testCsvDataset_withGzipCompressionType(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, + expected, + linebreak='\r\n', + compression_type='GZIP', + record_defaults=record_defaults) + + def testCsvDataset_withZlibCompressionType(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, + expected, + linebreak='\r\n', + compression_type='ZLIB', + record_defaults=record_defaults) + + def testCsvDataset_withScalarDefaults(self): + record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + 
record_defaults=record_defaults) + + def testCsvDataset_with2DDefaults(self): + record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + + if context.executing_eagerly(): + err_spec = errors.InvalidArgumentError, ( + 'Each record default should be at ' + 'most rank 1.') + else: + err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2' + + with self.assertRaisesWithPredicateMatch(*err_spec): + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + +class CsvDatasetBenchmark(test.Benchmark): + """Benchmarks for the various ways of creating a dataset from CSV files. + """ + FLOAT_VAL = '1.23456E12' + STR_VAL = string.ascii_letters * 10 + + def _setUp(self, str_val): + # Since this isn't test.TestCase, have to manually create a test dir + gfile.MakeDirs(googletest.GetTempDir()) + self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) + + self._num_cols = [4, 64, 256] + self._num_per_iter = 5000 + self._filenames = [] + for n in self._num_cols: + fn = os.path.join(self._temp_dir, 'file%d.csv' % n) + with open(fn, 'wb') as f: + # Just write 100 rows and use `repeat`... Assumes the cost + # of creating an iterator is not significant + row = ','.join([str_val for _ in range(n)]) + f.write('\n'.join([row for _ in range(100)])) + self._filenames.append(fn) + + def _tearDown(self): + gfile.DeleteRecursively(self._temp_dir) + + def _runBenchmark(self, dataset, num_cols, prefix): + dataset = dataset.skip(self._num_per_iter - 1) + deltas = [] + for _ in range(10): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: + start = time.time() + # NOTE: This depends on the underlying implementation of skip, to have + # the net effect of calling `GetNext` num_per_iter times on the + # input dataset. We do it this way (instead of a python for loop, or + # batching N inputs in one iter) so that the overhead from session.run + # or batch doesn't dominate. If we eventually optimize skip, this has + # to change. 
+ sess.run(next_element) + end = time.time() + deltas.append(end - start) + # Median wall time per CSV record read and decoded + median_wall_time = np.median(deltas) / self._num_per_iter + print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, + median_wall_time)) + self.report_benchmark( + iters=self._num_per_iter, + wall_time=median_wall_time, + name='%s_with_cols_%d' % (prefix, num_cols)) + + def benchmarkMapWithFloats(self): + self._setUp(self.FLOAT_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') + self._tearDown() + + def benchmarkMapWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') + self._tearDown() + + def benchmarkCsvDatasetWithFloats(self): + self._setUp(self.FLOAT_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset') + self._tearDown() + + def benchmarkCsvDatasetWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset') + self._tearDown() + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py new file mode 100644 index 0000000000..3fc7157bc5 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py @@ -0,0 +1,71 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class DatasetConstructorTest(test_base.DatasetTestBase): + + def testRestructureDataset(self): + components = (array_ops.placeholder(dtypes.int32), + (array_ops.placeholder(dtypes.int32, shape=[None]), + array_ops.placeholder(dtypes.int32, shape=[20, 30]))) + dataset = dataset_ops.Dataset.from_tensors(components) + + i32 = dtypes.int32 + + test_cases = [((i32, i32, i32), None), + (((i32, i32), i32), None), + ((i32, i32, i32), (None, None, None)), + ((i32, i32, i32), ([17], [17], [20, 30]))] + + for new_types, new_shape_lists in test_cases: + # pylint: disable=protected-access + new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) + # pylint: enable=protected-access + self.assertEqual(new_types, new.output_types) + if new_shape_lists is not None: + for expected_shape_list, shape in zip( + nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)): + if expected_shape_list is None: + self.assertIs(None, shape.ndims) + else: + self.assertEqual(expected_shape_list, shape.as_list()) + + fail_cases = [((i32, dtypes.int64, i32), None), + ((i32, i32, i32, i32), None), + ((i32, i32, i32), ((None, None), None)), + ((i32, i32, i32), (None, None, None, None)), + ((i32, i32, i32), (None, [None], [21, 30]))] + + for new_types, new_shape_lists in fail_cases: + with self.assertRaises(ValueError): + # pylint: disable=protected-access + new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) + # pylint: enable=protected-access + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py new file mode 100644 index 0000000000..7f435b8239 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py @@ -0,0 +1,692 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Base class for testing serializable datasets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import nest + + +def remove_variants(get_next_op): + # TODO(b/72408568): Remove this once session.run can get + # variant tensors. + """Remove variants from a nest structure, so sess.run will execute.""" + + def _remove_variant(x): + if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: + return () + else: + return x + + return nest.map_structure(_remove_variant, get_next_op) + + +class DatasetSerializationTestBase(test.TestCase): + """Base class for testing serializable datasets.""" + + def tearDown(self): + self._delete_ckpt() + + # TODO(b/72657739): Remove sparse_tensor argument, which is to test the + # (deprecated) saveable `SparseTensorSliceDataset`, once the API + # `from_sparse_tensor_slices()`and related tests are deleted. + def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): + """Runs the core tests. + + Args: + ds_fn1: 0-argument function that returns a Dataset. + ds_fn2: 0-argument function that returns a Dataset different from + ds_fn1. If None, verify_restore_in_modified_graph test is not run. + num_outputs: Total number of outputs expected from this Dataset. + sparse_tensors: Whether dataset is built from SparseTensor(s). + + Raises: + AssertionError if any test fails. + """ + self.verify_unused_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_fully_used_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_exhausted_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_init_before_restore( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_multiple_breaks( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_reset_restored_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_restore_in_empty_graph( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + if ds_fn2: + self.verify_restore_in_modified_graph( + ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors) + + def verify_unused_iterator(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that saving and restoring an unused iterator works. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. 
+ """ + self.verify_run_with_breaks( + ds_fn, [0], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_fully_used_iterator(self, ds_fn, num_outputs, + sparse_tensors=False): + """Verifies that saving and restoring a fully used iterator works. + + Note that this only checks saving and restoring an iterator from which + `num_outputs` items have been produced but does not check for an + exhausted iterator, i.e., one from which an OutOfRange error has been + returned. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if test fails. + """ + self.verify_run_with_breaks( + ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors) + + def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False): + """Verifies that saving and restoring an exhausted iterator works. + + An exhausted iterator is one which has returned an OutOfRange error. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + self.gen_outputs( + ds_fn, [], + num_outputs, + verify_exhausted=True, + sparse_tensors=sparse_tensors) + actual = self.gen_outputs( + ds_fn, [], + 0, + ckpt_saved=True, + verify_exhausted=True, + sparse_tensors=sparse_tensors) + self.assertEqual(len(actual), 0) + + def verify_init_before_restore(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that restoring into an already initialized iterator works. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs), + num_outputs, + init_before_restore=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_multiple_breaks(self, + ds_fn, + num_outputs, + num_breaks=10, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to save/restore at multiple break points. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + num_breaks: The number of break points. These are uniformly spread in + [0, num_outputs] both inclusive. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs, num_breaks), + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_reset_restored_iterator(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to re-initialize a restored iterator. + + This is useful when restoring a training checkpoint during validation. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Collect ground truth containing all outputs. + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Skip some items and save checkpoint. 
+ self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Restore from checkpoint and then run init_op. + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + self._restore(saver, sess) + self._initialize(init_op, sess) + for _ in range(num_outputs): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self.match(expected, actual) + + def verify_restore_in_modified_graph(self, + ds_fn1, + ds_fn2, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to restore an iterator in a modified graph. + + Builds an input pipeline using ds_fn1, runs it for `break_point` steps + and saves a checkpoint. Then builds a new graph using ds_fn2, restores + the checkpoint from ds_fn1 and verifies that the restore is successful. + + Args: + ds_fn1: See `run_core_tests`. + ds_fn2: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn1 + # in `expected`. + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + expected = self.gen_outputs( + ds_fn1, [], + num_outputs - break_point, + ckpt_saved=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Generate `break_point` items from ds_fn1 and save checkpoint. + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Build graph for ds_fn2 but load checkpoint for ds_fn1. + with ops.Graph().as_default() as g: + _, get_next_op, saver = self._build_graph( + ds_fn2, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.match(expected, actual) + + def verify_restore_in_empty_graph(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to restore an iterator in an empty graph. + + Builds an input pipeline using ds_fn, runs it for `break_point` steps + and saves a checkpoint. Then builds a new empty graph, restores + the checkpoint from ds_fn and verifies that the restore is successful. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn + # in `expected`. 
+ self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + expected = self.gen_outputs( + ds_fn, [], + num_outputs - break_point, + ckpt_saved=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Generate `break_point` items from ds_fn and save checkpoint. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Build an empty graph but load checkpoint for ds_fn. + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph( + ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.match(expected, actual) + + def verify_error_on_save(self, + ds_fn, + num_outputs, + error, + break_point=None, + sparse_tensors=False): + """Attempts to save a non-saveable iterator. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + error: Declared error when trying to save iterator. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + + break_point = num_outputs // 2 if not break_point else break_point + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + self._initialize(init_op, sess) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(error): + self._save(sess, saver) + + def verify_run_with_breaks(self, + ds_fn, + break_points, + num_outputs, + init_before_restore=False, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that ds_fn() produces the same outputs with and without breaks. + + 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it + *without* stopping at break points. + 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it + with stopping at break points. + + Deep matches outputs from 1 and 2. + + Args: + ds_fn: See `gen_outputs`. + break_points: See `gen_outputs`. + num_outputs: See `gen_outputs`. + init_before_restore: See `gen_outputs`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + actual = self.gen_outputs( + ds_fn, + break_points, + num_outputs, + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + self.match(expected, actual) + + def gen_outputs(self, + ds_fn, + break_points, + num_outputs, + ckpt_saved=False, + init_before_restore=False, + sparse_tensors=False, + verify_exhausted=True, + save_checkpoint_at_end=True): + """Generates elements from input dataset while stopping at break points. + + Produces `num_outputs` outputs and saves the state of the iterator in the + Saver checkpoint. + + Args: + ds_fn: 0-argument function that returns the dataset. + break_points: A list of integers. 
For each `break_point` in + `break_points`, we produce outputs till `break_point` number of items + have been produced and then checkpoint the state. The current graph + and session are destroyed and a new graph and session are used to + produce outputs till next checkpoint or till `num_outputs` elements + have been produced. `break_point` must be <= `num_outputs`. + num_outputs: The total number of outputs to produce from the iterator. + ckpt_saved: Whether a checkpoint already exists. If False, we build the + graph from ds_fn. + init_before_restore: Whether init should be called before saver.restore. + This is just so that we can verify that restoring an already initialized + iterator works. + sparse_tensors: Whether dataset is built from SparseTensor(s). + verify_exhausted: Whether to verify that the iterator has been exhausted + after producing `num_outputs` elements. + save_checkpoint_at_end: Whether to save a checkpoint after producing all + outputs. If False, checkpoints are saved each break point but not at the + end. Note that checkpoints overwrite each other so there is always only + a single checkpoint available. Defaults to True. + + Returns: + A list of `num_outputs` items. + """ + outputs = [] + + def get_ops(): + if ckpt_saved: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) + else: + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) + return init_op, get_next_op, saver + + for i in range(len(break_points) + 1): + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = get_ops() + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + if ckpt_saved: + if init_before_restore: + self._initialize(init_op, sess) + self._restore(saver, sess) + else: + self._initialize(init_op, sess) + start = break_points[i - 1] if i > 0 else 0 + end = break_points[i] if i < len(break_points) else num_outputs + num_iters = end - start + for _ in range(num_iters): + outputs.append(sess.run(get_next_op)) + if i == len(break_points) and verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + if save_checkpoint_at_end or i < len(break_points): + self._save(sess, saver) + ckpt_saved = True + + return outputs + + def match(self, expected, actual): + """Matches nested structures. + + Recursively matches shape and values of `expected` and `actual`. + Handles scalars, numpy arrays and other python sequence containers + e.g. list, dict. + + Args: + expected: Nested structure 1. + actual: Nested structure 2. + + Raises: + AssertionError if matching fails. 
+ """ + if isinstance(expected, np.ndarray): + expected = expected.tolist() + if isinstance(actual, np.ndarray): + actual = actual.tolist() + self.assertEqual(type(expected), type(actual)) + + if nest.is_sequence(expected): + self.assertEqual(len(expected), len(actual)) + if isinstance(expected, dict): + for key1, key2 in zip(sorted(expected), sorted(actual)): + self.assertEqual(key1, key2) + self.match(expected[key1], actual[key2]) + else: + for item1, item2 in zip(expected, actual): + self.match(item1, item2) + else: + self.assertEqual(expected, actual) + + def does_not_match(self, expected, actual): + with self.assertRaises(AssertionError): + self.match(expected, actual) + + def gen_break_points(self, num_outputs, num_samples=10): + """Generates `num_samples` breaks points in [0, num_outputs].""" + return np.linspace(0, num_outputs, num_samples, dtype=int) + + def _build_graph(self, ds_fn, sparse_tensors=False): + iterator = ds_fn().make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + self._add_iterator_ops_to_collection(init_op, get_next, ds_fn, + sparse_tensors) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _build_empty_graph(self, ds_fn, sparse_tensors=False): + iterator = iterator_ops.Iterator.from_structure( + self._get_output_types(ds_fn), + output_shapes=self._get_output_shapes(ds_fn), + output_classes=self._get_output_classes(ds_fn)) + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return get_next, saver + + def _add_iterator_ops_to_collection(self, + init_op, + get_next, + ds_fn, + sparse_tensors=False): + ops.add_to_collection("iterator_ops", init_op) + # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections + # do not support tuples we flatten the tensors and restore the shape in + # `_get_iterator_ops_from_collection`. + if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. + ops.add_to_collection("iterator_ops", get_next.indices) + ops.add_to_collection("iterator_ops", get_next.values) + ops.add_to_collection("iterator_ops", get_next.dense_shape) + return + + get_next_list = nest.flatten(get_next) + for i, output_class in enumerate( + nest.flatten(self._get_output_classes(ds_fn))): + if output_class is sparse_tensor.SparseTensor: + ops.add_to_collection("iterator_ops", get_next_list[i].indices) + ops.add_to_collection("iterator_ops", get_next_list[i].values) + ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape) + else: + ops.add_to_collection("iterator_ops", get_next_list[i]) + + def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): + all_ops = ops.get_collection("iterator_ops") + if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. 
+ init_op, indices, values, dense_shape = all_ops + return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) + get_next_list = [] + i = 1 + for output_class in nest.flatten(self._get_output_classes(ds_fn)): + if output_class is sparse_tensor.SparseTensor: + indices, values, dense_shape = all_ops[i:i + 3] + i += 3 + get_next_list.append( + sparse_tensor.SparseTensor(indices, values, dense_shape)) + else: + get_next_list.append(all_ops[i]) + i += 1 + return all_ops[0], nest.pack_sequence_as( + self._get_output_types(ds_fn), get_next_list) + + def _get_output_types(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_types + + def _get_output_shapes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_shapes + + def _get_output_classes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_classes + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return checkpoint_management.latest_checkpoint(self.get_temp_dir()) + + def _save(self, sess, saver): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + sess.run(lookup_ops.tables_initializer()) + saver.restore(sess, self._latest_ckpt()) + + def _initialize(self, init_op, sess): + sess.run(variables.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + sess.run(init_op) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _delete_ckpt(self): + # Remove all checkpoint files. + prefix = self._ckpt_path() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) diff --git a/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py new file mode 100644 index 0000000000..796a692c56 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py @@ -0,0 +1,148 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
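The save/restore helpers that end just above (gen_outputs, verify_run_with_breaks, and the checkpoint utilities) are meant to be driven from concrete test cases that supply a 0-argument dataset constructor. A minimal sketch of such a subclass follows; the module path and the base-class name DatasetSerializationTestBase are assumptions, since this hunk does not show them.

# Illustrative only: a concrete serialization test built on the helpers above.
# The import path and base-class name are assumed, not taken from the patch.
from tensorflow.python.data.experimental.kernel_tests import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


class RangeDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_range_dataset(self, stop):
    # 0-argument constructors are passed as `ds_fn`, matching `gen_outputs`.
    return dataset_ops.Dataset.range(stop)

  def testRunWithBreaks(self):
    num_outputs = 10
    # Checkpoint halfway through and verify the interrupted run matches an
    # uninterrupted one, using the helper defined above.
    self.verify_run_with_breaks(
        lambda: self._build_range_dataset(num_outputs),
        break_points=[num_outputs // 2],
        num_outputs=num_outputs)


if __name__ == "__main__":
  test.main()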
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.platform import test + + +class DirectedInterleaveDatasetTest(test_base.DatasetTestBase): + + def testBasic(self): + selector_dataset = dataset_ops.Dataset.range(10).repeat(100) + input_datasets = [ + dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) + ] + dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset, + input_datasets) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for _ in range(100): + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def _normalize(self, vec): + return vec / vec.sum() + + def _chi2(self, expected, actual): + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected, axis=0) + return chi2 + + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. + dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + + def testSampleFromDatasets(self): + random_seed.set_random_seed(1619) + num_samples = 5000 + rand_probs = self._normalize(np.random.random_sample((15,))) + + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "third_party/tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs, [1.]]: + probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + # Also check that `weights` as a dataset samples correctly. 
+ probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + def testSelectFromDatasets(self): + words = [b"foo", b"bar", b"baz"] + datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] + choice_array = np.random.randint(3, size=(15,), dtype=np.int64) + choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array) + dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + for i in choice_array: + self.assertEqual(words[i], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, + r"vector of length `len\(datasets\)`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[0.25, 0.25, 0.25, 0.25]) + + with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[1, 1]) + + with self.assertRaisesRegexp(TypeError, "must have the same type"): + interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(0.0) + ]) + + with self.assertRaisesRegexp(TypeError, "tf.int64"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0)) + + with self.assertRaisesRegexp(TypeError, "scalar"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py new file mode 100644 index 0000000000..c6ee88c676 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
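The directed-interleave tests above exercise sample_from_datasets and choose_from_datasets through their internal module. A short usage sketch of the same two entry points, separate from the patch itself:

# Illustrative usage of the APIs tested above (not part of the patch).
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops

datasets = [dataset_ops.Dataset.from_tensors(i).repeat() for i in range(3)]

# Randomly draw from the three datasets with probabilities 0.5 / 0.25 / 0.25.
# Weights must be float32/float64, as the error tests above check.
sampled = interleave_ops.sample_from_datasets(
    datasets, weights=[0.5, 0.25, 0.25])

# Deterministically pick the next input dataset from an int64 scalar dataset.
choices = dataset_ops.Dataset.range(3).repeat(2)
chosen = interleave_ops.choose_from_datasets(datasets, choices)

next_element = sampled.take(8).make_one_shot_iterator().get_next()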
+# ============================================================================== +"""Benchmarks FilterDataset input pipeline op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class FilterBenchmark(test.Benchmark): + + # This benchmark compares the performance of pipeline with multiple chained + # filter with and without filter fusion. + def benchmarkFilters(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkFilters(chain_length, False) + self._benchmarkFilters(chain_length, True) + + def _benchmarkFilters(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(5).repeat(None) + for _ in range(chain_length): + dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0)) + if optimize_dataset: + dataset = dataset.apply(optimization.optimize(["filter_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(10): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Filter dataset {} chain length: {} Median wall time: {}".format( + opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_filter_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py new file mode 100644 index 0000000000..8c07afbac5 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py @@ -0,0 +1,72 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
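The FilterBenchmark above toggles the "filter_fusion" rewrite by applying optimization.optimize(). A condensed sketch of the pipeline it times, with and without the rewrite (illustrative only):

# Illustrative: two chained filters, optionally fused by the optimizer.
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops

dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
dataset = dataset.filter(lambda x: math_ops.greater_equal(x, 0))
dataset = dataset.filter(lambda x: math_ops.less(x, 10))
# With the rewrite enabled, the two FilterDataset stages can be fused into one.
fused = dataset.apply(optimization.optimize(["filter_fusion"]))
next_element = fused.make_one_shot_iterator().get_next()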
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import get_single_element +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): + + @parameterized.named_parameters( + ("Zero", 0, 1), + ("Five", 5, 1), + ("Ten", 10, 1), + ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."), + ("MoreThanOne", 0, 2, errors.InvalidArgumentError, + "Dataset had more than one element."), + ) + def testGetSingleElement(self, skip, take, error=None, error_msg=None): + skip_t = array_ops.placeholder(dtypes.int64, shape=[]) + take_t = array_ops.placeholder(dtypes.int64, shape=[]) + + def make_sparse(x): + x_1d = array_ops.reshape(x, [1]) + x_2d = array_ops.reshape(x, [1, 1]) + return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d) + + dataset = dataset_ops.Dataset.range(100).skip(skip_t).map( + lambda x: (x * x, make_sparse(x))).take(take_t) + element = get_single_element.get_single_element(dataset) + + with self.cached_session() as sess: + if error is None: + dense_val, sparse_val = sess.run( + element, feed_dict={ + skip_t: skip, + take_t: take + }) + self.assertEqual(skip * skip, dense_val) + self.assertAllEqual([[skip]], sparse_val.indices) + self.assertAllEqual([skip], sparse_val.values) + self.assertAllEqual([skip], sparse_val.dense_shape) + else: + with self.assertRaisesRegexp(error, error_msg): + sess.run(element, feed_dict={skip_t: skip, take_t: take}) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py new file mode 100644 index 0000000000..c93a8353ce --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py @@ -0,0 +1,79 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
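get_single_element, tested above, returns the contents of a single-element dataset as plain tensors without building an iterator. A brief sketch (illustrative only):

# Illustrative: read the sole element of a dataset as tensors.
from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(100).skip(5).take(1).map(lambda x: x * x)
element = get_single_element.get_single_element(dataset)
# Evaluating `element` yields 25. Datasets with zero elements or with more
# than one element raise InvalidArgumentError, as the parameterized cases
# above verify.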
+# ============================================================================== +"""Tests for experimental indexed dataset ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +from tensorflow.python.data.experimental.ops import indexed_dataset_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops +from tensorflow.python.platform import test + + +class IndexedDatasetOpsTest(test_base.DatasetTestBase): + + def testLowLevelIndexedDatasetOps(self): + identity = ged_ops.experimental_identity_indexed_dataset( + ops.convert_to_tensor(16, dtype=dtypes.uint64)) + handle = ged_ops.experimental_materialized_index_dataset_handle( + container="", + shared_name="", + output_types=[dtypes.uint64], + output_shapes=[[]]) + materialize = ged_ops.experimental_indexed_dataset_materialize( + identity, handle) + index = array_ops.placeholder(dtypes.uint64) + get_op = ged_ops.experimental_indexed_dataset_get( + handle, index, output_types=[dtypes.uint64], output_shapes=[[]]) + + with self.cached_session() as sess: + sess.run(materialize) + self.assertEqual([3], sess.run(get_op, feed_dict={index: 3})) + + def testIdentityIndexedDataset(self): + ds = indexed_dataset_ops.IdentityIndexedDataset(16) + materialized = ds.materialize() + with self.cached_session() as sess: + sess.run(materialized.initializer) + placeholder = array_ops.placeholder(dtypes.uint64, shape=[]) + for i in range(16): + output = sess.run( + materialized.get(placeholder), feed_dict={placeholder: i}) + self.assertEqual([i], output) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(materialized.get(placeholder), feed_dict={placeholder: 16}) + + @unittest.skip("Requisite functionality currently unimplemented.") + def testIdentityIndexedDatasetIterator(self): + ds = indexed_dataset_ops.IdentityIndexedDataset(16) + itr = ds.make_initializable_iterator() + n = itr.get_next() + with self.cached_session() as sess: + sess.run(itr.initializer) + for i in range(16): + output = sess.run(n) + self.assertEqual(i, output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py new file mode 100644 index 0000000000..560902caad --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py @@ -0,0 +1,811 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
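The indexed-dataset tests above show the intended random-access pattern: materialize the dataset once, then fetch elements by index. A brief sketch of the higher-level path, which remains experimental:

# Illustrative: materialize an IdentityIndexedDataset and read it by index.
from tensorflow.python.data.experimental.ops import indexed_dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

ds = indexed_dataset_ops.IdentityIndexedDataset(16)
materialized = ds.materialize()
index = array_ops.placeholder(dtypes.uint64, shape=[])
get_element = materialized.get(index)
# After running `materialized.initializer`, feeding index=i yields [i];
# feeding an index >= 16 raises InvalidArgumentError, as tested above.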
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import math +import threading +import time + +from six.moves import zip_longest + +from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class ParallelInterleaveDatasetTest(test_base.DatasetTestBase): + + def setUp(self): + + self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) + self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) + self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) + self.sloppy = array_ops.placeholder(dtypes.bool, shape=[]) + self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[]) + self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[]) + + self.error = None + self.repeat_count = 2 + + # Set up threading events used to sequence when items are produced that + # are subsequently interleaved. These events allow us to deterministically + # simulate slowdowns and force sloppiness. + self.read_coordination_events = {} + self.write_coordination_events = {} + # input values [4, 5, 6] are the common case for the tests; set defaults + for i in range(4, 7): + self.read_coordination_events[i] = threading.Semaphore(0) + self.write_coordination_events[i] = threading.Event() + + def map_py_fn(x): + self.write_coordination_events[x].wait() + self.write_coordination_events[x].clear() + self.read_coordination_events[x].release() + if self.error: + err = self.error + self.error = None + raise err # pylint: disable=raising-bad-type + return x * x + + def map_fn(x): + return script_ops.py_func(map_py_fn, [x], x.dtype) + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(x) + return dataset.map(map_fn) + + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + def _interleave(self, lists, cycle_length, block_length): + """Python implementation of interleave used for testing.""" + num_open = 0 + + # `all_iterators` acts as a queue of iterators over each element of `lists`. + all_iterators = [iter(l) for l in lists] + + # `open_iterators` are the iterators whose elements are currently being + # interleaved. 
+ open_iterators = [] + for i in range(cycle_length): + if all_iterators: + open_iterators.append(all_iterators.pop(0)) + num_open += 1 + else: + open_iterators.append(None) + + while num_open or all_iterators: + for i in range(cycle_length): + if open_iterators[i] is None: + if all_iterators: + open_iterators[i] = all_iterators.pop(0) + num_open += 1 + else: + continue + for _ in range(block_length): + try: + yield next(open_iterators[i]) + except StopIteration: + open_iterators[i] = None + num_open -= 1 + break + + def testPythonImplementation(self): + input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], + [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] + + # Cycle length 1 acts like `Dataset.flat_map()`. + expected_elements = itertools.chain(*input_lists) + for expected, produced in zip(expected_elements, + self._interleave(input_lists, 1, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1. + expected_elements = [ + 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, + 6, 5, 6, 5, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 1))): + self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % + (index, expected, produced)) + + def testPythonImplementationBlockLength(self): + input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2 + expected_elements = [ + 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5, + 5, 6, 6, 5, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 2))): + self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % + (index, expected, produced)) + + def testPythonImplementationEmptyLists(self): + input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [], + [6, 6, 6, 6, 6, 6]] + + expected_elements = [ + 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 1))): + self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % + (index, expected, produced)) + + def _clear_coordination_events(self): + for i in range(4, 7): + self.read_coordination_events[i] = threading.Semaphore(0) + self.write_coordination_events[i].clear() + + def _allow_all_map_threads(self): + for i in range(4, 7): + self.write_coordination_events[i].set() + + def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): + # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and + # `Dataset.flat_map()` and is single-threaded. No synchronization required. 
+ with self.cached_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 1, + self.block_length: 1, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: prefetch_input_elements, + }) + + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1): + self.write_coordination_events[expected_element].set() + self.assertEqual(expected_element * expected_element, + sess.run(self.next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testSingleThreaded(self): + self._testSingleThreaded() + + def testSingleThreadedSloppy(self): + self._testSingleThreaded(sloppy=True) + + def testSingleThreadedPrefetch1Itr(self): + self._testSingleThreaded(prefetch_input_elements=1) + + def testSingleThreadedPrefetch1ItrSloppy(self): + self._testSingleThreaded(prefetch_input_elements=1, sloppy=True) + + def testSingleThreadedRagged(self): + # Tests a sequence with wildly different elements per iterator. + with self.cached_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [3, 7, 4], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + + # Add coordination values for 3 and 7 + self.read_coordination_events[3] = threading.Semaphore(0) + self.write_coordination_events[3] = threading.Event() + self.read_coordination_events[7] = threading.Semaphore(0) + self.write_coordination_events[7] = threading.Event() + + for expected_element in self._interleave( + [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1): + self.write_coordination_events[expected_element].set() + output = sess.run(self.next_element) + self.assertEqual(expected_element * expected_element, output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def _testTwoThreadsNoContention(self, sloppy=False): + # num_threads > 1. + # Explicit coordination should result in `Dataset.interleave()` behavior + with self.cached_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + self.read_coordination_events[expected_element].acquire() + done_first_event = True + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContention(self): + self._testTwoThreadsNoContention() + + def testTwoThreadsNoContentionSloppy(self): + self._testTwoThreadsNoContention(sloppy=True) + + def _testTwoThreadsNoContentionWithRaces(self, sloppy=False): + """Tests where all the workers race in producing elements. 
+ + Note: this is in contrast with the previous test which carefully sequences + the execution of the map functions. + + Args: + sloppy: Whether to be sloppy or not. + """ + with self.cached_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + if done_first_event: # First event starts the worker threads. + self._allow_all_map_threads() + self.read_coordination_events[expected_element].acquire() + else: + self.write_coordination_events[expected_element].set() + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.assertTrue( + self.read_coordination_events[expected_element].acquire(False)) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionWithRaces(self): + self._testTwoThreadsNoContentionWithRaces() + + def testTwoThreadsNoContentionWithRacesSloppy(self): + self._testTwoThreadsNoContentionWithRaces(sloppy=True) + + def _testTwoThreadsNoContentionBlockLength(self, sloppy=False): + # num_threads > 1. + # Explicit coordination should result in `Dataset.interleave()` behavior + with self.cached_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 2, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 2)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.read_coordination_events[expected_element].acquire() + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionBlockLength(self): + self._testTwoThreadsNoContentionBlockLength() + + def testTwoThreadsNoContentionBlockLengthSloppy(self): + self._testTwoThreadsNoContentionBlockLength(sloppy=True) + + def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False): + """Tests where all the workers race in producing elements. + + Note: this is in contrast with the previous test which carefully sequences + the execution of the map functions. + + + Args: + sloppy: Whether to be sloppy or not. 
+ """ + with self.cached_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 2, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 2)): + if done_first_event: # First event starts the worker threads. + self._allow_all_map_threads() + self.read_coordination_events[expected_element].acquire() + else: + self.write_coordination_events[expected_element].set() + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.assertTrue( + self.read_coordination_events[expected_element].acquire(False)) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionWithRacesAndBlocking(self): + self._testTwoThreadsNoContentionWithRacesAndBlocking() + + def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self): + self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) + + def _testEmptyInput(self, sloppy=False): + with self.cached_session() as sess: + # Empty input. + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [], + self.cycle_length: 2, + self.block_length: 3, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testEmptyInput(self): + self._testEmptyInput() + + def testEmptyInputSloppy(self): + self._testEmptyInput(sloppy=True) + + def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False): + # Non-empty input leading to empty output. + with self.cached_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [0, 0, 0], + self.cycle_length: 2, + self.block_length: 3, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testNonEmptyInputIntoEmptyOutputs(self): + self._testNonEmptyInputIntoEmptyOutputs() + + def testNonEmptyInputIntoEmptyOutputsSloppy(self): + self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) + + def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): + race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds + # Mixture of non-empty and empty interleaved datasets. + with self.cached_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 0, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: prefetch_input_elements, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): + self.write_coordination_events[expected_element].set() + # First event starts the worker threads. 
Additionally, when running the + # sloppy case with prefetch_input_elements=0, we get stuck if we wait + # for the read coordination event for certain event orderings in the + # presence of finishing iterators. + if done_first_event and not (sloppy and (i in race_indices)): + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event or (sloppy and (i in race_indices)): + done_first_event = True + self.read_coordination_events[expected_element].acquire() + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + + def testPartiallyEmptyOutputs(self): + self._testPartiallyEmptyOutputs() + + def testPartiallyEmptyOutputsSloppy(self): + self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0) + + def testDelayedOutputSloppy(self): + # Explicitly control the sequence of events to ensure we correctly avoid + # head-of-line blocking. + with self.cached_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: True, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + + mis_ordering = [ + 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6, + 6, 5, 5, 5, 5, 6, 6 + ] + for element in mis_ordering: + self.write_coordination_events[element].set() + self.assertEqual(element * element, sess.run(self.next_element)) + self.assertTrue(self.read_coordination_events[element].acquire(False)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testBlockLengthWithContentionSloppy(self): + with self.cached_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: True, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + # Test against a generating sequence that differs from the uncontended + # case, in order to prove sloppy correctness. + for i, expected_element in enumerate( + self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, + cycle_length=2, + block_length=3)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + self.read_coordination_events[expected_element].acquire() + done_first_event = True + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def _testEarlyExit(self, sloppy=False): + # Exiting without consuming all input should not block + with self.cached_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 3, + self.block_length: 2, + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i in range(4, 7): + self.write_coordination_events[i].set() + elem = sess.run(self.next_element) # Start all workers + # Allow the one successful worker to progress beyond the py_func again. 
+ elem = int(math.sqrt(elem)) + self.write_coordination_events[elem].set() + self.read_coordination_events[elem].acquire() + # Allow the prefetch to succeed + for i in range(4, 7): + self.read_coordination_events[i].acquire() + self.write_coordination_events[i].set() + + def testEarlyExit(self): + self._testEarlyExit() + + def testEarlyExitSloppy(self): + self._testEarlyExit(sloppy=True) + + def _testTooManyReaders(self, sloppy=False): + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64)) + return dataset + + dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]) + dataset = dataset.repeat(self.repeat_count) + dataset = dataset.apply( + interleave_ops.parallel_interleave( + interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) + iterator = dataset.make_one_shot_iterator() + + with self.cached_session() as sess: + output_values = [] + for _ in range(30): + output_values.append(sess.run(iterator.get_next())) + + expected_values = self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) + self.assertItemsEqual(output_values, expected_values) + + def testTooManyReaders(self): + self._testTooManyReaders() + + def testTooManyReadersSloppy(self): + self._testTooManyReaders(sloppy=True) + + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + dataset = dataset_ops.Dataset.range(10).map(_map_fn) + iterator = dataset.apply( + interleave_ops.parallel_interleave( + _interleave_fn, cycle_length=1)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testErrorsInOutputFn(self): + with self.cached_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + + except_on_element_indices = set([3]) + + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + if i in except_on_element_indices: + self.error = ValueError() + self.write_coordination_events[expected_element].set() + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + self.write_coordination_events[expected_element].set() + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testErrorsInInputFn(self): + + def map_py_fn(x): + if x == 5: + raise ValueError() + return x + + def map_fn(x): + return script_ops.py_func(map_py_fn, [x], x.dtype) + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(x) + return dataset + + self.dataset = ( + 
dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + with self.cached_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): + if expected_element == 5: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testErrorsInInterleaveFn(self): + + def map_py_fn(x): + if x == 5: + raise ValueError() + return x + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + y = script_ops.py_func(map_py_fn, [x], x.dtype) + dataset = dataset.repeat(y) + return dataset + + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + with self.cached_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): + if expected_element == 5: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testShutdownRace(self): + dataset = dataset_ops.Dataset.range(20) + map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1)) + dataset = dataset.apply( + interleave_ops.parallel_interleave( + map_fn, + cycle_length=3, + sloppy=False, + buffer_output_elements=1, + prefetch_input_elements=0)) + dataset = dataset.batch(32) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + results = [] + with self.cached_session() as sess: + for _ in range(2): + elements = [] + sess.run(iterator.initializer) + try: + while True: + elements.extend(sess.run(next_element)) + except errors.OutOfRangeError: + pass + results.append(elements) + + self.assertAllEqual(results[0], results[1]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py 
b/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py new file mode 100644 index 0000000000..94393d6d4b --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py @@ -0,0 +1,125 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for experimental iterator_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import iterator_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import training_util + + +class CheckpointInputPipelineHookTest(test_base.DatasetTestBase): + + @staticmethod + def _model_fn(features, labels, mode, config): + del labels + del mode + del config + global_step = training_util.get_or_create_global_step() + update_global_step_op = global_step.assign_add(1) + latest_feature = variables.VariableV1( + 0, name='latest_feature', dtype=dtypes.int64) + store_latest_feature_op = latest_feature.assign(features) + ops.add_to_collection('my_vars', global_step) + ops.add_to_collection('my_vars', latest_feature) + return model_fn.EstimatorSpec( + mode='train', + train_op=control_flow_ops.group( + [update_global_step_op, store_latest_feature_op]), + loss=constant_op.constant(2.0)) + + def _read_vars(self, model_dir): + """Returns (global_step, latest_feature).""" + with ops.Graph().as_default() as g: + ckpt_path = checkpoint_management.latest_checkpoint(model_dir) + meta_filename = ckpt_path + '.meta' + saver_lib.import_meta_graph(meta_filename) + saver = saver_lib.Saver() + with self.session(graph=g) as sess: + saver.restore(sess, ckpt_path) + return sess.run(ops.get_collection('my_vars')) + + def _build_iterator_saver_hook(self, est): + return iterator_ops.CheckpointInputPipelineHook(est) + + def testReturnDatasetFromInputFn(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testBuildIteratorInInputFn(self): + + def 
_input_fn(): + ds = dataset_ops.Dataset.range(10) + iterator = ds.make_one_shot_iterator() + return iterator.get_next() + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testDoNotRestore(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + # Hook not provided, input pipeline was not restored. + est.train(_input_fn, steps=2) + self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1)) + + def testRaiseErrorIfNoIterator(self): + + def _input_fn(): + return constant_op.constant(1, dtype=dtypes.int64) + + est = estimator.Estimator(model_fn=self._model_fn) + + with self.assertRaises(ValueError): + est.train( + _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py new file mode 100644 index 0000000000..2f0bd1456b --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py @@ -0,0 +1,359 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
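CheckpointInputPipelineHook, tested above, saves and restores the input pipeline's iterator state alongside the model checkpoint during Estimator training. A minimal wiring sketch; `my_model_fn` stands in for any model_fn and is not defined in the patch:

# Illustrative: checkpoint tf.data iterator state during Estimator training.
from tensorflow.python.data.experimental.ops import iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator


def _input_fn():
  # Returning a Dataset (or an iterator built from one) lets the hook find
  # and checkpoint the pipeline state.
  return dataset_ops.Dataset.range(10)


est = estimator.Estimator(model_fn=my_model_fn)  # `my_model_fn` is assumed.
hook = iterator_ops.CheckpointInputPipelineHook(est)
# A second short run resumes the pipeline where the first run stopped.
est.train(_input_fn, steps=2, hooks=[hook])
est.train(_input_fn, steps=2, hooks=[hook])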
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib +import itertools +import os +import time + +import numpy as np + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.experimental.ops import error_ops +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +_NUMPY_RANDOM_SEED = 42 + + +class MapDatasetTest(test_base.DatasetTestBase): + + def testMapIgnoreError(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors())) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for x in [1., 2., 3., 5.]: + self.assertEqual(x, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testParallelMapIgnoreError(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message"), + num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors())) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for x in [1., 2., 3., 5.]: + self.assertEqual(x, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testReadFileIgnoreError(self): + + def write_string_to_file(value, filename): + with open(filename, "w") as f: + f.write(value) + + filenames = [ + os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5) + ] + for filename in filenames: + write_string_to_file(filename, filename) + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(filenames).map( + io_ops.read_file, + num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors())) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + # All of the files are present. + sess.run(init_op) + for filename in filenames: + self.assertEqual(compat.as_bytes(filename), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Delete one of the files. + os.remove(filenames[0]) + + # Attempting to read filenames[0] will fail, but ignore_errors() + # will catch the error. 
+ sess.run(init_op) + for filename in filenames[1:]: + self.assertEqual(compat.as_bytes(filename), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testCaptureResourceInMapFn(self): + + def _build_ds(iterator): + + def _map_fn(x): + get_next = iterator.get_next() + return x * get_next + + return dataset_ops.Dataset.range(10).map(_map_fn) + + def _build_graph(): + captured_iterator = dataset_ops.Dataset.range( + 10).make_initializable_iterator() + ds = _build_ds(captured_iterator) + iterator = ds.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return captured_iterator.initializer, init_op, get_next + + with ops.Graph().as_default() as g: + captured_init_op, init_op, get_next = _build_graph() + with self.session(graph=g) as sess: + sess.run(captured_init_op) + sess.run(init_op) + for i in range(10): + self.assertEquals(i * i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class MapDatasetBenchmark(test.Benchmark): + + # The purpose of this benchmark is to compare the performance of chaining vs + # fusing of the map and batch transformations across various configurations. + # + # NOTE: It is recommended to build the benchmark with + # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt` + # and execute it on a machine with at least 32 CPU cores. + def benchmarkMapAndBatch(self): + + # Sequential pipeline configurations. + seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16]) + seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64]) + + # Parallel pipeline configuration. + par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256]) + par_batch_size_series = itertools.product([32], [32], [1], + [128, 256, 512, 1024]) + par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512]) + par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512]) + + def name(method, label, num_calls, inter_op, element_size, batch_size): + return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % ( + method, + hashlib.sha1(label).hexdigest(), + num_calls, + inter_op, + element_size, + batch_size, + )) + + def benchmark(label, series): + + print("%s:" % label) + for num_calls, inter_op, element_size, batch_size in series: + + num_iters = 1024 // ( + (element_size * batch_size) // min(num_calls, inter_op)) + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand( + element_size, 4 * k), np.random.rand(4 * k, 1))).repeat() + + chained_dataset = dataset.map( + math_ops.matmul, + num_parallel_calls=num_calls).batch(batch_size=batch_size) + chained_iterator = chained_dataset.make_one_shot_iterator() + chained_get_next = chained_iterator.get_next() + + chained_deltas = [] + with session.Session( + config=config_pb2.ConfigProto( + inter_op_parallelism_threads=inter_op, + use_per_session_threads=True)) as sess: + for _ in range(5): + sess.run(chained_get_next.op) + for _ in range(num_iters): + start = time.time() + sess.run(chained_get_next.op) + end = time.time() + chained_deltas.append(end - start) + + fused_dataset = dataset.apply( + batching.map_and_batch( + math_ops.matmul, + num_parallel_calls=num_calls, + batch_size=batch_size)) + fused_iterator = fused_dataset.make_one_shot_iterator() + fused_get_next = fused_iterator.get_next() + + fused_deltas = [] + with session.Session( + config=config_pb2.ConfigProto( + inter_op_parallelism_threads=inter_op, + 
use_per_session_threads=True)) as sess: + + for _ in range(5): + sess.run(fused_get_next.op) + for _ in range(num_iters): + start = time.time() + sess.run(fused_get_next.op) + end = time.time() + fused_deltas.append(end - start) + + print( + "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, " + "element size: %d, num iters: %d\nchained wall time: %f (median), " + "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: " + "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n " + "chained/fused: %.2fx (median), %.2fx (mean)" % + (batch_size, num_calls, inter_op, element_size, num_iters, + np.median(chained_deltas), np.mean(chained_deltas), + np.std(chained_deltas), np.min(chained_deltas), + np.max(chained_deltas), np.median(fused_deltas), + np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas), + np.max(fused_deltas), + np.median(chained_deltas) / np.median(fused_deltas), + np.mean(chained_deltas) / np.mean(fused_deltas))) + + self.report_benchmark( + iters=num_iters, + wall_time=np.median(chained_deltas), + name=name("chained", label, num_calls, inter_op, element_size, + batch_size)) + + self.report_benchmark( + iters=num_iters, + wall_time=np.median(fused_deltas), + name=name("fused", label, num_calls, inter_op, element_size, + batch_size)) + + print("") + + np.random.seed(_NUMPY_RANDOM_SEED) + benchmark("Sequential element size evaluation", seq_elem_size_series) + benchmark("Sequential batch size evaluation", seq_batch_size_series) + benchmark("Parallel element size evaluation", par_elem_size_series) + benchmark("Parallel batch size evaluation", par_batch_size_series) + benchmark("Transformation parallelism evaluation", par_num_calls_series) + benchmark("Threadpool size evaluation", par_inter_op_series) + + # This benchmark compares the performance of pipeline with multiple chained + # maps with and without map fusion. + def benchmarkChainOfMaps(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkChainOfMaps(chain_length, False) + self._benchmarkChainOfMaps(chain_length, True) + + def _benchmarkChainOfMaps(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x) + if optimize_dataset: + dataset = dataset.apply(optimization.optimize(["map_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map dataset {} chain length: {} Median wall time: {}".format( + opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + +class MapAndFilterBenchmark(test.Benchmark): + + # This benchmark compares the performance of pipeline with multiple chained + # map + filter with and without map fusion. 
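For orientation, the map-and-batch benchmark above measures the gap between chaining map() and batch() and using the fused transformation; in user code the two forms look roughly like this sketch (the public tf.data.experimental.map_and_batch endpoint is assumed):

import tensorflow as tf

dataset = tf.data.Dataset.range(10000)

# Chained form: separate Map and Batch dataset ops.
chained = dataset.map(lambda x: x * 2, num_parallel_calls=8).batch(32)

# Fused form: a single MapAndBatch op, which is what the benchmark compares against.
fused = dataset.apply(
    tf.data.experimental.map_and_batch(
        lambda x: x * 2, batch_size=32, num_parallel_calls=8))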
+ def benchmarkMapAndFilter(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkMapAndFilter(chain_length, False) + self._benchmarkMapAndFilter(chain_length, True) + + def _benchmarkMapAndFilter(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x + 5).filter( + lambda x: math_ops.greater_equal(x - 5, 0)) + if optimize_dataset: + dataset = dataset.apply( + optimization.optimize(["map_and_filter_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(10): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map and filter dataset {} chain length: {} Median wall time: {}". + format(opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py new file mode 100644 index 0000000000..612ee332c4 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py @@ -0,0 +1,281 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for MapDefunOp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from tensorflow.python.client import session +from tensorflow.python.data.experimental.ops import map_defun +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapDefunTest(test_base.DatasetTestBase): + + def testMapDefunSimple(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0] + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefunMismatchedTypes(self): + + @function.Defun(dtypes.int32) + def fn(x): + return math_ops.cast(x, dtypes.float64) + + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefunReduceDim(self): + # Tests where the output has a different rank from the input + + @function.Defun(dtypes.int32) + def fn(x): + return array_ops.gather(x, 0) + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] + expected = constant_op.constant([1, 3, 5]) + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefunMultipleOutputs(self): + + @function.Defun(dtypes.int32) + def fn(x): + return (x, math_ops.cast(x * 2 + 3, dtypes.float64)) + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,), + (2,)]) + expected = [elems, elems * 2 + 3] + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefunShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0] + self.assertEqual(result.get_shape(), (3, 2)) + + def testMapDefunPartialShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + elems = array_ops.placeholder(dtypes.int64, (None, 2)) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)]) + self.assertEqual(result[0].get_shape().as_list(), [None, 2]) + + def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def fn(x, y): + return x, y + + elems1 = array_ops.placeholder(dtypes.int32) + elems2 = array_ops.placeholder(dtypes.int32) + result = map_defun.map_defun(fn, [elems1, elems2], + [dtypes.int32, 
dtypes.int32], [(), ()]) + with self.cached_session() as sess: + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "All inputs must have the same dimension 0."): + sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]}) + + def testMapDefunRaisesDefunError(self): + + @function.Defun(dtypes.int32) + def fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return array_ops.identity(x) + + elems = constant_op.constant([0, 0, 0, 37, 0]) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()]) + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(result) + + def testMapDefunCancelledCorrectly(self): + + @function.Defun(dtypes.int64) + def defun(x): + # x has leading dimension 5, this will raise an error + return array_ops.gather(x, 10) + + c = array_ops.tile( + array_ops.expand_dims( + constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0), + [100, 1]) + map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0] + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"indices = 10 is not in \[0, 5\)"): + self.evaluate(map_defun_op) + + def testMapDefunWithUnspecifiedOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + res = x * 2 + 3 + return (res, res + 1, res + 2) + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], + [dtypes.int32, dtypes.int32, dtypes.int32], + [None, (None,), (2,)]) + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected)) + self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1)) + self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2)) + + def testMapDefunWithDifferentOutputShapeEachRun(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + elems = array_ops.placeholder(dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0] + with session.Session() as sess: + self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3]) + self.assertAllEqual( + sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]]) + + def testMapDefunWithWrongOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefunWithInvalidInput(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + + c = constant_op.constant(2) + with self.assertRaises(ValueError): + # Fails at graph construction time for inputs with known shapes. + r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0] + p = array_ops.placeholder(dtypes.int32) + r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0] + with session.Session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(r, feed_dict={p: 0}) + + def _assert_op_cancelled(self, sess, map_defun_op): + with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"): + sess.run(map_defun_op) + + def testMapDefunWithParentCancellation(self): + # Checks that a cancellation of the parent graph is threaded through to + # MapDefunOp correctly. 
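As a reference for the map_defun tests in this file: map_defun applies a Defun-wrapped function across dimension 0 of its inputs, given explicit output dtypes and per-element output shapes. A minimal sketch using the same internal modules (it mirrors testMapDefunSimple above):

from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import map_defun
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function

@function.Defun(dtypes.int32)
def double_plus_three(x):
  return x * 2 + 3

elems = constant_op.constant([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
# One int32 output with per-element shape (2,); the stacked result has shape (3, 2).
result = map_defun.map_defun(double_plus_three, [elems], [dtypes.int32], [(2,)])[0]

with session.Session() as sess:
  print(sess.run(result))  # [[5, 7], [9, 11], [13, 15]]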
+ @function.Defun(dtypes.int32) + def simple_fn(x): + del x + queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ()) + # Blocking + return queue.dequeue_many(5) + + c = constant_op.constant([1, 2, 3, 4, 5]) + map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0] + + with self.cached_session() as sess: + thread = self.checkedThread( + self._assert_op_cancelled, args=(sess, map_defun_op)) + thread.start() + time.sleep(0.1) + sess.close() + thread.join() + + +class MapDefunBenchmark(test.Benchmark): + + def _run(self, op, name=None, num_iters=3000): + with session.Session() as sess: + # Warm up the session + for _ in range(5): + sess.run(op) + start = time.time() + for _ in range(num_iters): + sess.run(op) + end = time.time() + mean_us = (end - start) * 1e6 / num_iters + self.report_benchmark( + name=name, + iters=num_iters, + wall_time=mean_us, + extras={"examples_per_sec": num_iters / (end - start)}) + + def benchmarkDefunVsMapFn(self): + """Benchmarks to compare the performance of MapDefun vs tf.map_fn.""" + + @function.Defun(dtypes.int32) + def defun(x): + return array_ops.identity(x) + + def map_fn(x): + return array_ops.identity(x) + + base = math_ops.range(100) + for input_size in [10, 100, 1000, 10000]: + num_iters = 100000 // input_size + map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()]) + map_fn_op = functional_ops.map_fn(map_fn, base) + + self._run( + map_defun_op, + "benchmarkMapDefun_size_%d" % input_size, + num_iters=num_iters) + self._run( + map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD new file mode 100644 index 0000000000..68f73bddb5 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -0,0 +1,164 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_test( + name = "assert_next_dataset_op_test", + size = "medium", + srcs = ["assert_next_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "hoist_random_uniform_test", + size = "small", + srcs = ["hoist_random_uniform_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "latency_all_edges_test", + size = "small", + srcs = ["latency_all_edges_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/experimental/ops:stats_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = 
"map_vectorization_test", + size = "small", + srcs = ["map_vectorization_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:session", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "map_and_filter_fusion_test", + size = "medium", + srcs = ["map_and_filter_fusion_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "map_parallelization_test", + size = "small", + srcs = ["map_parallelization_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "model_dataset_op_test", + size = "medium", + srcs = ["model_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "optonly", + ], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/experimental/ops:interleave_ops", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "noop_elimination_test", + size = "small", + srcs = ["noop_elimination_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/experimental/ops:interleave_ops", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "optimize_dataset_op_test", + size = "small", + srcs = ["optimize_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py new file mode 100644 index 0000000000..45b77b5c20 --- /dev/null +++ 
b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py @@ -0,0 +1,65 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class AssertNextDatasetTest(test_base.DatasetTestBase): + + def testAssertNext(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + self.assertEqual(0, sess.run(get_next)) + + def testAssertNextInvalid(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted Whoops transformation at offset 0 but encountered " + "Map transformation instead."): + sess.run(get_next) + + def testAssertNextShort(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted next 2 transformations but encountered only 1."): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py new file mode 100644 index 0000000000..3cd9753665 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
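For orientation, the assert_next() transformation used throughout these optimization tests is a test-only hook: it raises InvalidArgumentError at iteration time unless the dataset ops that follow it match the given names, which is how the tests confirm that a Grappler rewrite actually fired. A minimal sketch with the same internal modules:

from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops

# Succeeds because the next transformation really is a Map; asserting a name
# such as "Whoops" would fail at iteration time, as in testAssertNextInvalid above.
dataset = dataset_ops.Dataset.from_tensors(0).apply(
    optimization.assert_next(["Map"])).map(lambda x: x)
get_next = dataset.make_one_shot_iterator().get_next()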
+# ============================================================================== +"""Tests for HostState optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase): + + @staticmethod + def map_functions(): + plus_one = lambda x: x + 1 + + def random(_): + return random_ops.random_uniform([], + minval=1, + maxval=10, + dtype=dtypes.float32, + seed=42) + + def random_with_assert(x): + y = random(x) + assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y]) + with ops.control_dependencies([assert_op]): + return y + + twice_random = lambda x: (random(x) + random(x)) / 2. + + tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True), + ("RandomWithAssert", random_with_assert, True), + ("TwiceRandom", twice_random, False)] + return tuple(tests) + + @parameterized.named_parameters(*map_functions.__func__()) + def testHoisting(self, function, will_optimize): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next( + ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function) + + dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"])) + self._testDataset(dataset) + + def testAdditionalInputs(self): + a = constant_op.constant(1, dtype=dtypes.float32) + b = constant_op.constant(0, dtype=dtypes.float32) + some_tensor = math_ops.mul(a, b) + + def random_with_capture(_): + return some_tensor + random_ops.random_uniform( + [], minval=1, maxval=10, dtype=dtypes.float32, seed=42) + + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next( + ["Zip[0]", "Map"])).map(random_with_capture).apply( + optimization.optimize(["hoist_random_uniform"])) + self._testDataset(dataset) + + def _testDataset(self, dataset): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + previous_result = 0 + with self.cached_session() as sess: + for _ in range(5): + result = sess.run(get_next) + self.assertLessEqual(1, result) + self.assertLessEqual(result, 10) + # This checks if the result is somehow random by checking if we are not + # generating the same values. + self.assertNotEqual(previous_result, result) + previous_result = result + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py new file mode 100644 index 0000000000..45623876ae --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
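A note on the rewrite tested above: hoist_random_uniform moves the seeded random_uniform call out of the user-defined map function and zips its output with the input dataset, which is why the optimized pipeline is asserted to start with "Zip[0]" followed by "Map". Opting in looks roughly like this sketch (internal module paths as used in the test):

from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import random_ops

dataset = dataset_ops.Dataset.range(5).map(
    lambda _: random_ops.random_uniform(
        [], minval=1, maxval=10, dtype=dtypes.float32, seed=42))
# Opt in to the rewrite; once it fires, the remaining map function no longer
# contains the random op, so other rewrites (e.g. map_parallelization) can apply.
dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"]))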
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the LatencyAllEdges optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.experimental.ops import stats_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): + + def testLatencyStatsOptimization(self): + + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.from_tensors(1).apply( + optimization.assert_next( + ["LatencyStats", "Map", "LatencyStats", "Prefetch", + "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply( + stats_ops.set_stats_aggregator(stats_aggregator)).apply( + optimization.optimize(["latency_all_edges"])) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + self.assertEqual(1 * 1, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, + "record_latency_TensorDataset/_1", 1) + self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", + 1) + self._assertSummaryHasCount(summary_str, + "record_latency_PrefetchDataset/_6", 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py new file mode 100644 index 0000000000..a439635716 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -0,0 +1,225 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
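For context on the test above: latency_all_edges inserts LatencyStats nodes between dataset transformations, and the recorded latencies are surfaced through the StatsAggregator attached to the pipeline. A rough sketch of the user-facing wiring (internal module paths as in the test):

from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops

aggregator = stats_ops.StatsAggregator()
dataset = (dataset_ops.Dataset.from_tensors(1)
           .map(lambda x: x * x)
           .prefetch(1)
           .apply(stats_ops.set_stats_aggregator(aggregator))
           .apply(optimization.optimize(["latency_all_edges"])))
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()
# After the pipeline has run, this tensor evaluates to a serialized Summary
# containing a latency histogram for each instrumented edge.
summary_t = aggregator.get_summary()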
+# ============================================================================== +"""Tests for the MapAndFilterFusion optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): + + @staticmethod + def map_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + + def increment_and_square(x): + y = x + 1 + return y * y + + functions = [identity, increment, increment_and_square] + tests = [] + for i, fun1 in enumerate(functions): + for j, fun2 in enumerate(functions): + tests.append(( + "Test{}{}".format(i, j), + [fun1, fun2], + )) + for k, fun3 in enumerate(functions): + tests.append(( + "Test{}{}{}".format(i, j, k), + [fun1, fun2, fun3], + )) + + swap = lambda x, n: (n, x) + tests.append(( + "Swap1", + [lambda x: (x, 42), swap], + )) + tests.append(( + "Swap2", + [lambda x: (x, 42), swap, swap], + )) + return tuple(tests) + + @parameterized.named_parameters(*map_functions.__func__()) + def testMapFusion(self, functions): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(["Map", "Prefetch"])) + for function in functions: + dataset = dataset.map(function) + + dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.cached_session() as sess: + for x in range(5): + result = sess.run(get_next) + r = x + for function in functions: + if isinstance(r, tuple): + r = function(*r) # Pass tuple as multiple arguments. 
+ else: + r = function(r) + self.assertAllEqual(r, result) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @staticmethod + def map_and_filter_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + minus_five = lambda x: x - 5 + + def increment_and_square(x): + y = x + 1 + return y * y + + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + is_odd = lambda x: math_ops.equal(x % 2, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + functions = [identity, increment, minus_five, increment_and_square] + filters = [take_all, is_zero, is_odd, greater] + tests = [] + + for x, fun in enumerate(functions): + for y, predicate in enumerate(filters): + tests.append(("Mixed{}{}".format(x, y), fun, predicate)) + + # Multi output + tests.append(("Multi1", lambda x: (x, x), + lambda x, y: constant_op.constant(True))) + tests.append( + ("Multi2", lambda x: (x, 2), + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) + return tuple(tests) + + @parameterized.named_parameters(*map_and_filter_functions.__func__()) + def testMapFilterFusion(self, function, predicate): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", + "FilterByLastComponent"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + self._testMapAndFilter(dataset, function, predicate) + + def _testMapAndFilter(self, dataset, function, predicate): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.cached_session() as sess: + for x in range(10): + r = function(x) + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if sess.run(b): + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testAdditionalInputs(self): + a = constant_op.constant(3, dtype=dtypes.int64) + b = constant_op.constant(4, dtype=dtypes.int64) + some_tensor = math_ops.mul(a, b) + function = lambda x: x * x + + def predicate(y): + return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor) + + # We are currently not supporting functions with additional inputs. 
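By contrast, when neither the map function nor the predicate captures additional inputs, the rewrite does fire and the asserted op sequence is a Map followed by FilterByLastComponent, as in testMapFilterFusion above. A rough sketch of opting in:

from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops

dataset = (dataset_ops.Dataset.range(10)
           .map(lambda x: x + 1)
           .filter(lambda y: math_ops.equal(y % 2, 0))
           .apply(optimization.optimize(["map_and_filter_fusion"])))
# After the rewrite, a single Map computes (x + 1, predicate(x + 1)) and a
# FilterByLastComponent op drops elements whose predicate component is False.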
+ dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Filter"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + + self._testMapAndFilter(dataset, function, predicate) + + @staticmethod + def filter_functions(): + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + tests = [] + filters = [take_all, is_zero, greater] + identity = lambda x: x + for x, predicate_1 in enumerate(filters): + for y, predicate_2 in enumerate(filters): + tests.append(("Mixed{}{}".format(x, y), identity, + [predicate_1, predicate_2])) + for z, predicate_3 in enumerate(filters): + tests.append(("Mixed{}{}{}".format(x, y, z), identity, + [predicate_1, predicate_2, predicate_3])) + + take_all_multiple = lambda x, y: constant_op.constant(True) + # Multi output + tests.append(("Multi1", lambda x: (x, x), + [take_all_multiple, take_all_multiple])) + tests.append(("Multi2", lambda x: (x, 2), [ + take_all_multiple, + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) + ])) + return tuple(tests) + + @parameterized.named_parameters(*filter_functions.__func__()) + def testFilterFusion(self, map_function, predicates): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(["Map", "Filter", + "Prefetch"])).map(map_function) + for predicate in predicates: + dataset = dataset.filter(predicate) + + dataset = dataset.prefetch(0).apply( + optimization.optimize(["filter_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.cached_session() as sess: + for x in range(5): + r = map_function(x) + filtered = False + for predicate in predicates: + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if not sess.run(b): + filtered = True + break + + if not filtered: + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py new file mode 100644 index 0000000000..334d8e3778 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py @@ -0,0 +1,85 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the MapParallelization optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase): + + @staticmethod + def map_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + + def assert_greater(x): + assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x]) + with ops.control_dependencies([assert_op]): + return x + + def random(_): + return random_ops.random_uniform([], + minval=0, + maxval=10, + dtype=dtypes.int64, + seed=42) + + def assert_with_random(x): + x = assert_greater(x) + return random(x) + + return (("Identity", identity, True), ("Increment", increment, True), + ("AssertGreater", assert_greater, True), ("Random", random, False), + ("AssertWithRandom", assert_with_random, False)) + + @parameterized.named_parameters(*map_functions.__func__()) + def testMapParallelization(self, function, should_optimize): + next_nodes = ["ParallelMap"] if should_optimize else ["Map"] + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(next_nodes)).map(function).apply( + optimization.optimize(["map_parallelization"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + for x in range(5): + result = sess.run(get_next) + # No need to run the pipeline if it was not optimized. Also the results + # might be hard to check because of random. + if not should_optimize: + return + r = function(x) + self.assertAllEqual(r, result) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py new file mode 100644 index 0000000000..d47492753e --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -0,0 +1,223 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the MapVectorization optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): + + def _get_test_datasets(self, + base_dataset, + map_fn, + num_parallel_calls=None, + expect_optimized=True): + """Given base dataset and map fn, creates test datasets. + + Returns a tuple of (unoptimized, dataset, optimized dataset). The + unoptimized dataset has the assertion that Batch follows Map. The optimized + dataset has the assertion that Map follows Batch, and has the + "map_vectorization" optimization applied. + + Args: + base_dataset: Input dataset to map->batch + map_fn: Map function to use + num_parallel_calls: (Optional.) num_parallel_calls argument for map + expect_optimized: (Optional.) Whether we expect the optimization to take + place, in which case we will assert that Batch is followed by Map, + otherwise Map followed by Batch. Defaults to True. + + Returns: + Tuple of (unoptimized dataset, optimized dataset). 
+ """ + map_node_name = "Map" if num_parallel_calls is None else "ParallelMap" + batch_size = 100 + + def _make_dataset(node_names): + return base_dataset.apply(optimization.assert_next(node_names)).map( + map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size) + + unoptimized = _make_dataset([map_node_name, "Batch"]) + optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else + [map_node_name, "Batch"]).apply( + optimization.optimize(["map_vectorization"])) + + return unoptimized, optimized + + @parameterized.named_parameters( + ("Basic", lambda x: (x, x + 1), None), + ("Parallel", lambda x: (x, x + 1), 12), + ("Gather", lambda x: array_ops.gather(x, 0), 12), + ) + def testOptimization(self, map_fn, num_parallel_calls): + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn, + num_parallel_calls) + self.assertDatasetsEqual(unoptimized, optimized) + + def testOptimizationBadMapFn(self): + # Test map functions that give an error + def map_fn(x): + # x has leading dimension 5, this will raise an error + return array_ops.gather(x, 10) + + base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch( + 5, drop_remainder=True) + _, optimized = self._get_test_datasets(base_dataset, map_fn) + nxt = optimized.make_one_shot_iterator().get_next() + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"indices = 10 is not in \[0, 5\)"): + self.evaluate(nxt) + + def testOptimizationWithCapturedInputs(self): + # Tests that vectorization works with captured inputs + def map_fn(x): + return x + y + + y = constant_op.constant(1, shape=(2,)) + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + # TODO(rachelim): when this optimization works, turn on expect_optimized + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self.assertDatasetsEqual(optimized, unoptimized) + + def testOptimizationIgnoreStateful(self): + + def map_fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return array_ops.identity(x) + + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self.assertDatasetsRaiseSameError( + unoptimized, optimized, errors.InvalidArgumentError, + [("OneShotIterator", "OneShotIterator_1", 1), + ("IteratorGetNext", "IteratorGetNext_1", 1)]) + + def testOptimizationIgnoreRagged(self): + # Make sure we ignore inputs that might not be uniformly sized + def map_fn(x): + return array_ops.gather(x, 0) + + # output_shape = (?,) + base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False) + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self.assertDatasetsEqual(unoptimized, optimized) + + def testOptimizationIgnoreRaggedMap(self): + # Don't optimize when the output of the map fn shapes are unknown. 
+ def map_fn(x): + return array_ops.tile(x, x) + + base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True) + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self.assertDatasetsRaiseSameError( + unoptimized, optimized, errors.InvalidArgumentError, + [("OneShotIterator", "OneShotIterator_1", 1), + ("IteratorGetNext", "IteratorGetNext_1", 1)]) + + +class MapVectorizationBenchmark(test.Benchmark): + # TODO(rachelim): Add a benchmark for more expensive transformations, such as + # vgg_preprocessing. + + def _run(self, x, num_iters=100, name=None): + deltas = [] + with session.Session() as sess: + for _ in range(5): + # Warm up session... + sess.run(x) + for _ in range(num_iters): + start = time.time() + sess.run(x) + end = time.time() + deltas.append(end - start) + median_time = np.median(deltas) + self.report_benchmark(iters=num_iters, wall_time=median_time, name=name) + return median_time + + def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id): + num_elems = np.prod(input_size) + name_template = "{}__batch_size_{}_input_size_{}_{}" + unoptimized = input_dataset.map(map_fn).batch(batch_size) + unoptimized_op = unoptimized.make_one_shot_iterator().get_next() + + optimized = unoptimized.apply(optimization.optimize(["map_vectorization"])) + optimized_op = optimized.make_one_shot_iterator().get_next() + + unoptimized_time = self._run( + unoptimized_op, + name=name_template.format(str_id, batch_size, num_elems, "unoptimized")) + optimized_time = self._run( + optimized_op, + name=name_template.format(str_id, batch_size, num_elems, "optimized")) + + print("Batch size: {}\n" + "Input size: {}\n" + "Transformation: {}\n" + "Speedup: {}\n".format(batch_size, input_size, str_id, + (unoptimized_time / optimized_time))) + + # Known cheap functions + def benchmarkIdentity(self): + self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args], + "identity") + + def benchmarkAddConst(self): + self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const") + + def benchmarkSelect(self): + self._benchmark_helper(lambda *args: args[0], "select") + + def benchmarkCast(self): + self._benchmark_helper( + lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast") + + def _benchmark_helper(self, map_fn, str_id): + input_sizes = [(10, 10, 3), (10, 100, 300)] + batch_size = 1000 + for input_size in input_sizes: + input_dataset = dataset_ops.Dataset.from_tensor_slices( + (np.random.rand(*input_size), np.random.rand(*input_size))).repeat() + self._compare(input_dataset, map_fn, batch_size, input_size, str_id) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py new file mode 100644 index 0000000000..a9f2ce8c03 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py @@ -0,0 +1,183 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ModelDatasetTest(test_base.DatasetTestBase): + + def testModelMap(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map(math_ops.matmul) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.cached_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(100): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelParallelMap(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map( + math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.cached_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(1000): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelMapAndBatch(self): + batch_size = 16 + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.apply( + batching.map_and_batch( + math_ops.matmul, + num_parallel_calls=optimization.AUTOTUNE, + batch_size=batch_size)) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.cached_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(10): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelParallelInterleave(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map(math_ops.matmul) + dataset = 
dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: dataset, + cycle_length=10, + num_parallel_calls=optimization.AUTOTUNE) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.cached_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(1000): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelNested(self): + k = 1024 * 1024 + a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1)) + b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1)) + c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1)) + dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat() + + def f1(a, b, c): + x, y = a + return math_ops.matmul(x, y), b, c + + def f2(a, b, c): + x, y = b + return a, math_ops.matmul(x, y), c + + def f3(a, b, c): + x, y = c + return a, b, math_ops.matmul(x, y) + + dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE) + dataset = dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: dataset, cycle_length=2) + + dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE) + dataset = dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: dataset, cycle_length=2) + + dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.cached_session() as sess: + for _ in range(5): + sess.run(get_next) + for _ in range(100): + start = time.time() + sess.run(get_next) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py new file mode 100644 index 0000000000..092e0ff62a --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
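A note on the ModelDataset tests above: num_parallel_calls=optimization.AUTOTUNE asks for the degree of parallelism to be chosen automatically, and optimization.model() wraps the pipeline in the op that performs that tuning, which is why the tests always apply it last. A minimal sketch in the style of those tests:

from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(1000).map(
    lambda x: x * 2, num_parallel_calls=optimization.AUTOTUNE)
# model() is applied to the whole pipeline (as in the tests above) so the
# autotuning logic can observe and adjust every tunable transformation inside it.
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()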
+# ==============================================================================
+"""Tests for the NoopElimination optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class NoopEliminationTest(test_base.DatasetTestBase):
+
+  def testNoopElimination(self):
+    a = constant_op.constant(1, dtype=dtypes.int64)
+    b = constant_op.constant(2, dtype=dtypes.int64)
+    some_tensor = math_ops.mul(a, b)
+
+    dataset = dataset_ops.Dataset.range(5)
+    dataset = dataset.apply(
+        optimization.assert_next(
+            ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
+    dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
+        0).repeat(1).prefetch(0)
+    dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
+
+    iterator = dataset.make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      for x in range(5):
+        result = sess.run(get_next)
+        self.assertAllEqual(result, x)
+
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
new file mode 100644
index 0000000000..eb661796c0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test_base.DatasetTestBase): + + def testOptimizationDefault(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimizationEmpty(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimizationFusion(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimizationStatefulFunction(self): + dataset = dataset_ops.Dataset.range(10).map( + lambda _: random_ops.random_uniform([])).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(get_next) + + def testOptimizationLargeInputFromTensor(self): + input_t = array_ops.placeholder(dtypes.int32, (None, None, None)) + dataset = dataset_ops.Dataset.from_tensors(input_t).apply( + optimization.optimize()) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) + sess.run(get_next) + + def testOptimizationLargeInputFromTensorSlices(self): + input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None)) + dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply( + optimization.optimize()) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py 
b/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py new file mode 100644 index 0000000000..13f924b656 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py @@ -0,0 +1,850 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.parsing_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import numpy as np + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsing_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + +# Helpers for creating Example objects +example = example_pb2.Example +feature = feature_pb2.Feature +features = lambda d: feature_pb2.Features(feature=d) +bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v)) +int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v)) +float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v)) +# Helpers for creating SequenceExample objects +feature_list = lambda l: feature_pb2.FeatureList(feature=l) +feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d) +sequence_example = example_pb2.SequenceExample + + +def _compare_output_to_expected(tester, dict_tensors, expected_tensors, + flat_output): + tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys())) + + i = 0 # Index into the flattened output of session.run() + for k, v in sorted(dict_tensors.items()): + # TODO(shivaniagrawal): flat_output is same as v. + expected_v = expected_tensors[k] + tf_logging.info("Comparing key: %s", k) + print("i", i, "flat_output", flat_output[i], "expected_v", expected_v) + if sparse_tensor.is_sparse(v): + # Three outputs for SparseTensor : indices, values, shape. + tester.assertEqual([k, len(expected_v)], [k, 3]) + print("i", i, "flat_output", flat_output[i].indices, "expected_v", + expected_v[0]) + tester.assertAllEqual(expected_v[0], flat_output[i].indices) + tester.assertAllEqual(expected_v[1], flat_output[i].values) + tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape) + else: + # One output for standard Tensor. 
+ tester.assertAllEqual(expected_v, flat_output[i]) + i += 1 + + +class ParseExampleTest(test_base.DatasetTestBase): + + def _test(self, + input_tensor, + feature_val, + expected_values=None, + expected_err=None): + + with self.cached_session() as sess: + if expected_err: + with self.assertRaisesWithPredicateMatch(expected_err[0], + expected_err[1]): + dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply( + contrib_parsing_ops.parse_example_dataset(feature_val)) + get_next = dataset.make_one_shot_iterator().get_next() + sess.run(get_next) + return + else: + # Returns dict w/ Tensors and SparseTensors. + # Check values. + dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply( + contrib_parsing_ops.parse_example_dataset(feature_val)) + get_next = dataset.make_one_shot_iterator().get_next() + result = sess.run(get_next) + flattened = nest.flatten(result) + print("result", result, "expected_values", expected_values) + _compare_output_to_expected(self, result, expected_values, flattened) + + # Check shapes; if serialized is a Tensor we need its size to + # properly check. + batch_size = ( + input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else + np.asarray(input_tensor).size) + for k, f in feature_val.items(): + print("output_shapes as list ", + tuple(dataset.output_shapes[k].as_list())) + if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None: + self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size) + elif isinstance(f, parsing_ops.VarLenFeature): + self.assertEqual(dataset.output_shapes[k].as_list()[1], None) + + def testEmptySerializedWithAllDefaults(self): + sparse_name = "st_a" + a_name = "a" + b_name = "b" + c_name = "c:has_a_tricky_name" + a_default = [0, 42, 0] + b_default = np.random.rand(3, 3).astype(bytes) + c_default = np.random.rand(2).astype(np.float32) + + expected_st_a = ( # indices, values, shape + np.empty( + (0, 2), dtype=np.int64), # indices + np.empty( + (0,), dtype=np.int64), # sp_a is DT_INT64 + np.array( + [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 + + expected_output = { + sparse_name: expected_st_a, + a_name: np.array(2 * [[a_default]]), + b_name: np.array(2 * [b_default]), + c_name: np.array(2 * [c_default]), + } + + self._test( + ops.convert_to_tensor(["", ""]), { + sparse_name: + parsing_ops.VarLenFeature(dtypes.int64), + a_name: + parsing_ops.FixedLenFeature( + (1, 3), dtypes.int64, default_value=a_default), + b_name: + parsing_ops.FixedLenFeature( + (3, 3), dtypes.string, default_value=b_default), + c_name: + parsing_ops.FixedLenFeature( + (2,), dtypes.float32, default_value=c_default), + }, + expected_values=expected_output) + + def testEmptySerializedWithoutDefaultsShouldFail(self): + input_features = { + "st_a": + parsing_ops.VarLenFeature(dtypes.int64), + "a": + parsing_ops.FixedLenFeature( + (1, 3), dtypes.int64, default_value=[0, 42, 0]), + "b": + parsing_ops.FixedLenFeature( + (3, 3), + dtypes.string, + default_value=np.random.rand(3, 3).astype(bytes)), + # Feature "c" is missing a default, this gap will cause failure. + "c": + parsing_ops.FixedLenFeature( + (2,), dtype=dtypes.float32), + } + + # Edge case where the key is there but the feature value is empty + original = example(features=features({"c": feature()})) + self._test( + [original.SerializeToString()], + input_features, + expected_err=(errors_impl.InvalidArgumentError, + "Feature: c \\(data type: float\\) is required")) + + # Standard case of missing key and value. 
+ self._test( + ["", ""], + input_features, + expected_err=(errors_impl.InvalidArgumentError, + "Feature: c \\(data type: float\\) is required")) + + def testDenseNotMatchingShapeShouldFail(self): + original = [ + example(features=features({ + "a": float_feature([1, 1, 3]), + })), example(features=features({ + "a": float_feature([-1, -1]), + })) + ] + + serialized = [m.SerializeToString() for m in original] + + self._test( + ops.convert_to_tensor(serialized), + {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)}, + expected_err=(errors_impl.InvalidArgumentError, + "Key: a, Index: 1. Number of float values")) + + def testDenseDefaultNoShapeShouldFail(self): + original = [example(features=features({"a": float_feature([1, 1, 3]),})),] + + serialized = [m.SerializeToString() for m in original] + + self._test( + ops.convert_to_tensor(serialized), + {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)}, + expected_err=(ValueError, "Missing shape for feature a")) + + def testSerializedContainingSparse(self): + original = [ + example(features=features({ + "st_c": float_feature([3, 4]) + })), + example(features=features({ + "st_c": float_feature([]), # empty float list + })), + example(features=features({ + "st_d": feature(), # feature with nothing in it + })), + example(features=features({ + "st_c": float_feature([1, 2, -1]), + "st_d": bytes_feature([b"hi"]) + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_st_c = ( # indices, values, shape + np.array( + [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array( + [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array( + [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3 + + expected_st_d = ( # indices, values, shape + np.array( + [[3, 0]], dtype=np.int64), np.array( + ["hi"], dtype=bytes), np.array( + [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1 + + expected_output = { + "st_c": expected_st_c, + "st_d": expected_st_d, + } + + self._test( + ops.convert_to_tensor(serialized), { + "st_c": parsing_ops.VarLenFeature(dtypes.float32), + "st_d": parsing_ops.VarLenFeature(dtypes.string) + }, + expected_values=expected_output) + + def testSerializedContainingSparseFeature(self): + original = [ + example(features=features({ + "val": float_feature([3, 4]), + "idx": int64_feature([5, 10]) + })), + example(features=features({ + "val": float_feature([]), # empty float list + "idx": int64_feature([]) + })), + example(features=features({ + "val": feature(), # feature with nothing in it + # missing idx feature + })), + example(features=features({ + "val": float_feature([1, 2, -1]), + "idx": + int64_feature([0, 9, 3]) # unsorted + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_sp = ( # indices, values, shape + np.array( + [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64), + np.array( + [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array( + [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + + expected_output = {"sp": expected_sp,} + + self._test( + ops.convert_to_tensor(serialized), + {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])}, + expected_values=expected_output) + + def testSerializedContainingSparseFeatureReuse(self): + original = [ + example(features=features({ + "val1": float_feature([3, 4]), + "val2": float_feature([5, 6]), + "idx": int64_feature([5, 10]) + })), + example(features=features({ + "val1": float_feature([]), # empty float list + "idx": int64_feature([]) + })), + ] + + serialized = [m.SerializeToString() for m 
in original] + + expected_sp1 = ( # indices, values, shape + np.array( + [[0, 5], [0, 10]], dtype=np.int64), np.array( + [3.0, 4.0], dtype=np.float32), np.array( + [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13 + + expected_sp2 = ( # indices, values, shape + np.array( + [[0, 5], [0, 10]], dtype=np.int64), np.array( + [5.0, 6.0], dtype=np.float32), np.array( + [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13 + + expected_output = { + "sp1": expected_sp1, + "sp2": expected_sp2, + } + + self._test( + ops.convert_to_tensor(serialized), { + "sp1": + parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13), + "sp2": + parsing_ops.SparseFeature( + "idx", "val2", dtypes.float32, size=7, already_sorted=True) + }, + expected_values=expected_output) + + def testSerializedContaining3DSparseFeature(self): + original = [ + example(features=features({ + "val": float_feature([3, 4]), + "idx0": int64_feature([5, 10]), + "idx1": int64_feature([0, 2]), + })), + example(features=features({ + "val": float_feature([]), # empty float list + "idx0": int64_feature([]), + "idx1": int64_feature([]), + })), + example(features=features({ + "val": feature(), # feature with nothing in it + # missing idx feature + })), + example(features=features({ + "val": float_feature([1, 2, -1]), + "idx0": int64_feature([0, 9, 3]), # unsorted + "idx1": int64_feature([1, 0, 2]), + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_sp = ( + # indices + np.array( + [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]], + dtype=np.int64), + # values + np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), + # shape batch == 4, max_elems = 13 + np.array([4, 13, 3], dtype=np.int64)) + + expected_output = {"sp": expected_sp,} + + self._test( + ops.convert_to_tensor(serialized), { + "sp": + parsing_ops.SparseFeature(["idx0", "idx1"], "val", + dtypes.float32, [13, 3]) + }, + expected_values=expected_output) + + def testSerializedContainingDense(self): + aname = "a" + bname = "b*has+a:tricky_name" + original = [ + example(features=features({ + aname: float_feature([1, 1]), + bname: bytes_feature([b"b0_str"]), + })), example(features=features({ + aname: float_feature([-1, -1]), + bname: bytes_feature([b""]), + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + aname: + np.array( + [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1), + bname: + np.array( + ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1), + } + + # No defaults, values required + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32), + bname: + parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string), + }, + expected_values=expected_output) + + # This test is identical as the previous one except + # for the creation of 'serialized'. + def testSerializedContainingDenseWithConcat(self): + aname = "a" + bname = "b*has+a:tricky_name" + # TODO(lew): Feature appearing twice should be an error in future. 
+ original = [ + (example(features=features({ + aname: float_feature([10, 10]), + })), example(features=features({ + aname: float_feature([1, 1]), + bname: bytes_feature([b"b0_str"]), + }))), + ( + example(features=features({ + bname: bytes_feature([b"b100"]), + })), + example(features=features({ + aname: float_feature([-1, -1]), + bname: bytes_feature([b"b1"]), + })),), + ] + + serialized = [ + m.SerializeToString() + n.SerializeToString() for (m, n) in original + ] + + expected_output = { + aname: + np.array( + [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1), + bname: + np.array( + ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1), + } + + # No defaults, values required + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32), + bname: + parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string), + }, + expected_values=expected_output) + + def testSerializedContainingDenseScalar(self): + original = [ + example(features=features({ + "a": float_feature([1]), + })), example(features=features({})) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + "a": + np.array( + [[1], [-1]], dtype=np.float32) # 2x1 (column vector) + } + + self._test( + ops.convert_to_tensor(serialized), { + "a": + parsing_ops.FixedLenFeature( + (1,), dtype=dtypes.float32, default_value=-1), + }, + expected_values=expected_output) + + def testSerializedContainingDenseWithDefaults(self): + original = [ + example(features=features({ + "a": float_feature([1, 1]), + })), + example(features=features({ + "b": bytes_feature([b"b1"]), + })), + example(features=features({ + "b": feature() + })), + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + "a": + np.array( + [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2, + 1), + "b": + np.array( + ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1, + 1), + } + + self._test( + ops.convert_to_tensor(serialized), { + "a": + parsing_ops.FixedLenFeature( + (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]), + "b": + parsing_ops.FixedLenFeature( + (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"), + }, + expected_values=expected_output) + + def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self): + expected_st_a = ( # indices, values, shape + np.empty( + (0, 2), dtype=np.int64), # indices + np.empty( + (0,), dtype=np.int64), # sp_a is DT_INT64 + np.array( + [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 + expected_sp = ( # indices, values, shape + np.array( + [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array( + ["a", "b", "c"], dtype="|S"), np.array( + [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + + original = [ + example(features=features({ + "c": float_feature([3, 4]), + "val": bytes_feature([b"a", b"b"]), + "idx": int64_feature([0, 3]) + })), example(features=features({ + "c": float_feature([1, 2]), + "val": bytes_feature([b"c"]), + "idx": int64_feature([7]) + })) + ] + + serialized = [m.SerializeToString() for m in original] + + a_default = [1, 2, 3] + b_default = np.random.rand(3, 3).astype(bytes) + expected_output = { + "st_a": expected_st_a, + "sp": expected_sp, + "a": np.array(2 * [[a_default]]), + "b": np.array(2 * [b_default]), + "c": np.array( + [[3, 4], [1, 2]], dtype=np.float32), + } + + self._test( + ops.convert_to_tensor(serialized), + { + "st_a": + parsing_ops.VarLenFeature(dtypes.int64), + "sp": + parsing_ops.SparseFeature("idx", "val", 
dtypes.string, 13), + "a": + parsing_ops.FixedLenFeature( + (1, 3), dtypes.int64, default_value=a_default), + "b": + parsing_ops.FixedLenFeature( + (3, 3), dtypes.string, default_value=b_default), + # Feature "c" must be provided, since it has no default_value. + "c": + parsing_ops.FixedLenFeature((2,), dtypes.float32), + }, + expected_values=expected_output) + + def testSerializedContainingSparseAndSparseFeatureWithReuse(self): + expected_idx = ( # indices, values, shape + np.array( + [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64), + np.array([0, 3, 7, 1]), np.array( + [2, 2], dtype=np.int64)) # batch == 4, max_elems = 2 + + expected_sp = ( # indices, values, shape + np.array( + [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array( + ["a", "b", "d", "c"], dtype="|S"), np.array( + [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + + original = [ + example(features=features({ + "val": bytes_feature([b"a", b"b"]), + "idx": int64_feature([0, 3]) + })), example(features=features({ + "val": bytes_feature([b"c", b"d"]), + "idx": int64_feature([7, 1]) + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + "idx": expected_idx, + "sp": expected_sp, + } + + self._test( + ops.convert_to_tensor(serialized), { + "idx": + parsing_ops.VarLenFeature(dtypes.int64), + "sp": + parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]), + }, + expected_values=expected_output) + + def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size): + # During parsing, data read from the serialized proto is stored in buffers. + # For small batch sizes, a buffer will contain one minibatch entry. + # For larger batch sizes, a buffer may contain several minibatch + # entries. This test identified a bug where the code that copied + # data out of the buffers and into the output tensors assumed each + # buffer only contained one minibatch entry. The bug has since been fixed. + truth_int = [i for i in range(batch_size)] + truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()] + for i in range(batch_size)] + + expected_str = copy.deepcopy(truth_str) + + # Delete some intermediate entries + for i in range(batch_size): + col = 1 + if np.random.rand() < 0.25: + # w.p. 25%, drop out the second entry + expected_str[i][col] = b"default" + col -= 1 + truth_str[i].pop() + if np.random.rand() < 0.25: + # w.p. 25%, drop out the second entry (possibly again) + expected_str[i][col] = b"default" + truth_str[i].pop() + + expected_output = { + # Batch size batch_size, 1 time step. + "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1), + # Batch size batch_size, 2 time steps. 
+ "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2), + } + + original = [ + example(features=features( + {"a": int64_feature([truth_int[i]]), + "b": bytes_feature(truth_str[i])})) + for i in range(batch_size) + ] + + serialized = [m.SerializeToString() for m in original] + + self._test( + ops.convert_to_tensor(serialized, dtype=dtypes.string), { + "a": + parsing_ops.FixedLenSequenceFeature( + shape=(), + dtype=dtypes.int64, + allow_missing=True, + default_value=-1), + "b": + parsing_ops.FixedLenSequenceFeature( + shape=[], + dtype=dtypes.string, + allow_missing=True, + default_value="default"), + }, + expected_values=expected_output) + + def testSerializedContainingVarLenDenseLargerBatch(self): + np.random.seed(3456) + for batch_size in (1, 10, 20, 100, 256): + self._testSerializedContainingVarLenDenseLargerBatch(batch_size) + + def testSerializedContainingVarLenDense(self): + aname = "a" + bname = "b" + cname = "c" + dname = "d" + original = [ + example(features=features({ + cname: int64_feature([2]), + })), + example(features=features({ + aname: float_feature([1, 1]), + bname: bytes_feature([b"b0_str", b"b1_str"]), + })), + example(features=features({ + aname: float_feature([-1, -1, 2, 2]), + bname: bytes_feature([b"b1"]), + })), + example(features=features({ + aname: float_feature([]), + cname: int64_feature([3]), + })), + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + aname: + np.array( + [ + [0, 0, 0, 0], + [1, 1, 0, 0], + [-1, -1, 2, 2], + [0, 0, 0, 0], + ], + dtype=np.float32).reshape(4, 2, 2, 1), + bname: + np.array( + [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]], + dtype=bytes).reshape(4, 2, 1, 1, 1), + cname: + np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1), + dname: + np.empty(shape=(4, 0), dtype=bytes), + } + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), dtype=dtypes.float32, allow_missing=True), + bname: + parsing_ops.FixedLenSequenceFeature( + (1, 1, 1), dtype=dtypes.string, allow_missing=True), + cname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.int64, allow_missing=True), + dname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.string, allow_missing=True), + }, + expected_values=expected_output) + + # Test with padding values. + expected_output_custom_padding = dict(expected_output) + expected_output_custom_padding[aname] = np.array( + [ + [-2, -2, -2, -2], + [1, 1, -2, -2], + [-1, -1, 2, 2], + [-2, -2, -2, -2], + ], + dtype=np.float32).reshape(4, 2, 2, 1) + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), + dtype=dtypes.float32, + allow_missing=True, + default_value=-2.0), + bname: + parsing_ops.FixedLenSequenceFeature( + (1, 1, 1), dtype=dtypes.string, allow_missing=True), + cname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.int64, allow_missing=True), + dname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.string, allow_missing=True), + }, expected_output_custom_padding) + + # Change number of required values so the inputs are not a + # multiple of this size. + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), dtype=dtypes.float32, allow_missing=True), + bname: + parsing_ops.FixedLenSequenceFeature( + (2, 1, 1), dtype=dtypes.string, allow_missing=True), + }, + expected_err=( + errors_impl.OpError, "Key: b, Index: 2. 
" + "Number of bytes values is not a multiple of stride length.")) + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), + dtype=dtypes.float32, + allow_missing=True, + default_value=[]), + bname: + parsing_ops.FixedLenSequenceFeature( + (2, 1, 1), dtype=dtypes.string, allow_missing=True), + }, + expected_err=(ValueError, + "Cannot reshape a tensor with 0 elements to shape")) + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32), + bname: + parsing_ops.FixedLenSequenceFeature( + (2, 1, 1), dtype=dtypes.string, allow_missing=True), + }, + expected_err=(ValueError, + "First dimension of shape for feature a unknown. " + "Consider using FixedLenSequenceFeature.")) + + self._test( + ops.convert_to_tensor(serialized), { + cname: + parsing_ops.FixedLenFeature( + (1, None), dtype=dtypes.int64, default_value=[[1]]), + }, + expected_err=(ValueError, + "All dimensions of shape for feature c need to be known " + r"but received \(1, None\).")) + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), dtype=dtypes.float32, allow_missing=True), + bname: + parsing_ops.FixedLenSequenceFeature( + (1, 1, 1), dtype=dtypes.string, allow_missing=True), + cname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.int64, allow_missing=False), + dname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.string, allow_missing=True), + }, + expected_err=(ValueError, + "Unsupported: FixedLenSequenceFeature requires " + "allow_missing to be True.")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py new file mode 100644 index 0000000000..7d7b842c17 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py @@ -0,0 +1,948 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for prefetching_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.compat import compat +from tensorflow.python.data.experimental.ops import prefetching_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + +class PrefetchingKernelsOpsTest(test_base.DatasetTestBase): + + def setUp(self): + self._event = threading.Event() + + def _create_ds_and_iterator(self, device0, initializable=False): + + def gen(): + for i in range(1, 10): + yield [float(i)] + if i == 6: + self._event.set() + + with ops.device(device0): + ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) + if initializable: + ds_iterator = ds.make_initializable_iterator() + else: + ds_iterator = ds.make_one_shot_iterator() + return (ds, ds_iterator) + + def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1): + ds_iterator_handle = ds_iterator.string_handle() + + @function.Defun(dtypes.string) + def _remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, ds.output_types, ds.output_shapes) + return remote_iterator.get_next() + + target = constant_op.constant(device0) + with ops.device(device1): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_remote_fn, + output_types=[dtypes.float32], + target_device=target, + string_arg=ds_iterator_handle, + buffer_size=3, + shared_name=buffer_name) + + with ops.device(device1): + prefetch_op = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=buffer_resource_handle, + output_types=[dtypes.float32]) + reset_op = prefetching_ops.function_buffering_resource_reset( + function_buffer_resource=buffer_resource_handle) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + return (prefetch_op, reset_op, destroy_op) + + def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1): + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False) + prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name, + device0, device1) + + with self.test_session(config=worker_config) as sess: + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + sess.run(destroy_op) + + def testSameDeviceCPU(self): + self._prefetch_fn_helper_one_shot("same_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:0") + + def 
testDifferentDeviceCPU(self):
+    self._prefetch_fn_helper_one_shot("diff_device_cpu",
+                                      "/job:localhost/replica:0/task:0/cpu:0",
+                                      "/job:localhost/replica:0/task:0/cpu:1")
+
+  def testDifferentDeviceCPUGPU(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    self._prefetch_fn_helper_one_shot("cpu_gpu",
+                                      "/job:localhost/replica:0/task:0/cpu:0",
+                                      "/job:localhost/replica:0/task:0/gpu:0")
+
+  def testReinitialization(self):
+    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+    device0 = "/job:localhost/replica:0/task:0/cpu:0"
+    device1 = "/job:localhost/replica:0/task:0/cpu:1"
+    ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+    prefetch_op, reset_op, destroy_op = self._create_ops(
+        ds, ds_iterator, "reinit", device0, device1)
+
+    with self.test_session(config=worker_config) as sess:
+      sess.run(ds_iterator.initializer)
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [1.0])
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [2.0])
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [3.0])
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [4.0])
+      self._event.wait()
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [5.0])
+      # Let's reset the function buffering resource and reinitialize the
+      # iterator. Should be able to go through this again.
+      self._event.clear()
+      sess.run(reset_op)
+      sess.run(ds_iterator.initializer)
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [1.0])
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [2.0])
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [3.0])
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [4.0])
+      self._event.wait()
+      elem = sess.run(prefetch_op)
+      self.assertEqual(elem, [5.0])
+      sess.run(destroy_op)
+
+  def testReinitializationOutOfRange(self):
+    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+    device0 = "/job:localhost/replica:0/task:0/cpu:0"
+    device1 = "/job:localhost/replica:0/task:0/cpu:1"
+    ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+    prefetch_op, reset_op, destroy_op = self._create_ops(
+        ds, ds_iterator, "reinit", device0, device1)
+
+    with self.test_session(config=worker_config) as sess:
+      sess.run(ds_iterator.initializer)
+      for i in range(1, 10):
+        elem = sess.run(prefetch_op)
+        self.assertEqual(elem, [float(i)])
+      # Try fetching after it's over twice to test out end of sequence.
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(prefetch_op)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(prefetch_op)
+
+      # Now reset everything and try it out again.
+      self._event.clear()
+      sess.run(reset_op)
+      sess.run(ds_iterator.initializer)
+      for i in range(1, 10):
+        elem = sess.run(prefetch_op)
+        self.assertEqual(elem, [float(i)])
+      # Try fetching after it's over twice to test out end of sequence.
+ with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + + def testStringsGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/gpu:0" + + ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]) + ds_iterator = ds.make_one_shot_iterator() + ds_iterator_handle = ds_iterator.string_handle() + + @function.Defun(dtypes.string) + def _remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, ds.output_types, ds.output_shapes) + return remote_iterator.get_next() + + target = constant_op.constant(device0) + with ops.device(device1): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_remote_fn, + output_types=[dtypes.string], + target_device=target, + string_arg=ds_iterator_handle, + buffer_size=3, + shared_name="strings") + + with ops.device(device1): + prefetch_op = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=buffer_resource_handle, + output_types=[dtypes.string]) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + with self.cached_session() as sess: + self.assertEqual([b"a"], sess.run(prefetch_op)) + self.assertEqual([b"b"], sess.run(prefetch_op)) + self.assertEqual([b"c"], sess.run(prefetch_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + + +class PrefetchToDeviceTest(test_base.DatasetTestBase): + + def testPrefetchToDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToSameDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device( + "/job:localhost/replica:0/task:0/device:CPU:0")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. 
+ with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + with self.cached_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchDictToDevice(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchSparseTensorsToDevice(self): + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2]) + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. 
+ with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceWithReInit(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. 
+ with ops.device("/cpu:0"): + iterator = device_dataset.make_initializable_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpuWithReInit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + +class CopyToDeviceTest(test_base.DatasetTestBase): + + def testCopyToDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceInt32(self): + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, 
device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int32, next_element.dtype) + self.assertEqual((4,), next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToSameDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:0")) + + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyDictToDevice(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + 
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyDictToDeviceWithPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopySparseTensorsToDevice(self): + + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2]) + + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopySparseTensorsToDeviceWithPrefetch(self): + + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2]) + + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + 
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuInt32(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuInt32AndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuStrings(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) + with 
self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuStringsAndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDevicePingPongCPUGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with compat.forward_compatibility_horizon(2018, 8, 4): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0")) + back_to_cpu_dataset = device_dataset.apply( + prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0")) + + with ops.device("/cpu:0"): + iterator = back_to_cpu_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithReInit(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithReInitAndPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + 
self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithReInit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithReInitAndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testIteratorGetNextAsOptionalOnGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(3) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_elem = iterator_ops.get_next_as_optional(iterator) + elem_has_value_t = next_elem.has_value() + elem_value_t = next_elem.get_value() + + with self.cached_session() as sess: + # Before initializing the iterator, evaluating the optional fails with + # a FailedPreconditionError. + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_has_value_t) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_value_t) + + # For each element of the dataset, assert that the optional evaluates to + # the expected value. + sess.run(iterator.initializer) + for i in range(3): + elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t]) + self.assertTrue(elem_has_value) + self.assertEqual(i, elem_value) + + # After exhausting the iterator, `next_elem.has_value()` will evaluate to + # false, and attempting to get the value will fail. 
+ for _ in range(2): + self.assertFalse(sess.run(elem_has_value_t)) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(elem_value_t) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py new file mode 100644 index 0000000000..22412c3965 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py @@ -0,0 +1,78 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test RangeDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import counter +from tensorflow.python.data.experimental.ops import enumerate_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import test + + +class RangeDatasetTest(test_base.DatasetTestBase): + + def testEnumerateDataset(self): + components = (["a", "b"], [1, 2], [37.0, 38]) + start = constant_op.constant(20, dtype=dtypes.int64) + + iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply( + enumerate_ops.enumerate_dataset(start)).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual(dtypes.int64, get_next[0].dtype) + self.assertEqual((), get_next[0].shape) + self.assertEqual([tensor_shape.TensorShape([])] * 3, + [t.shape for t in get_next[1]]) + + with self.cached_session() as sess: + sess.run(init_op) + self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next)) + self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next)) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testCounter(self): + """Test dataset construction using `count`.""" + iterator = (counter.Counter(start=3, step=4) + .make_one_shot_iterator()) + get_next = iterator.get_next() + self.assertEqual([], get_next.shape.as_list()) + self.assertEqual(dtypes.int64, get_next.dtype) + + negative_iterator = (counter.Counter(start=0, step=-1) + .make_one_shot_iterator()) + negative_get_next = negative_iterator.get_next() + + with self.cached_session() as sess: + self.assertEqual(3, sess.run(get_next)) + self.assertEqual(3 + 4, sess.run(get_next)) + self.assertEqual(3 + 2 * 4, sess.run(get_next)) + + self.assertEqual(0, sess.run(negative_get_next)) + self.assertEqual(-1, sess.run(negative_get_next)) + self.assertEqual(-2, sess.run(negative_get_next)) + + +if __name__ == "__main__": + test.main() diff --git 
a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py new file mode 100644 index 0000000000..a02f4bd14f --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py @@ -0,0 +1,1083 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base +from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import nest +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class ReadBatchFeaturesTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): + + def testRead(self): + for batch_size in [1, 2]: + for num_epochs in [1, 10]: + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + # Basic test: read from file 0. + self.outputs = self.make_batch_feature( + filenames=self.test_filenames[0], + label_key="label", + num_epochs=num_epochs, + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records( + sess, + batch_size, + 0, + num_epochs=num_epochs, + label_key_provided=True) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess, label_key_provided=True) + + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + # Basic test: read from file 1. + self.outputs = self.make_batch_feature( + filenames=self.test_filenames[1], + label_key="label", + num_epochs=num_epochs, + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records( + sess, + batch_size, + 1, + num_epochs=num_epochs, + label_key_provided=True) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess, label_key_provided=True) + + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + # Basic test: read from both files. 
+ self.outputs = self.make_batch_feature( + filenames=self.test_filenames, + label_key="label", + num_epochs=num_epochs, + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records( + sess, + batch_size, + num_epochs=num_epochs, + label_key_provided=True) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess, label_key_provided=True) + + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + # Basic test: read from both files. + self.outputs = self.make_batch_feature( + filenames=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, num_epochs=num_epochs) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess) + + def testReadWithEquivalentDataset(self): + features = { + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + } + dataset = ( + core_readers.TFRecordDataset(self.test_filenames) + .map(lambda x: parsing_ops.parse_single_example(x, features)) + .repeat(10).batch(2)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for file_batch, _, _, _, record_batch, _ in self._next_expected_batch( + range(self._num_files), 2, 10): + actual_batch = sess.run(next_element) + self.assertAllEqual(file_batch, actual_batch["file"]) + self.assertAllEqual(record_batch, actual_batch["record"]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testReadWithFusedShuffleRepeatDataset(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + for batch_size in [1, 2]: + # Test that shuffling with same seed produces the same result. + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + outputs1 = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5).make_one_shot_iterator().get_next() + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + # Test that shuffling with different seeds produces a different order. 
+ with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + outputs1 = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=15).make_one_shot_iterator().get_next() + all_equal = True + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + + def testParallelReadersAndParsers(self): + num_epochs = 5 + for batch_size in [1, 2]: + for reader_num_threads in [2, 4]: + for parser_num_threads in [2, 4]: + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + self.outputs = self.make_batch_feature( + filenames=self.test_filenames, + label_key="label", + num_epochs=num_epochs, + batch_size=batch_size, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( + sess, + batch_size, + num_epochs=num_epochs, + label_key_provided=True, + interleave_cycle_length=reader_num_threads) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess, label_key_provided=True) + + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + self.outputs = self.make_batch_feature( + filenames=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( + sess, + batch_size, + num_epochs=num_epochs, + interleave_cycle_length=reader_num_threads) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess) + + def testDropFinalBatch(self): + for batch_size in [1, 2]: + for num_epochs in [1, 10]: + with ops.Graph().as_default(): + # Basic test: read from file 0. + outputs = self.make_batch_feature( + filenames=self.test_filenames[0], + label_key="label", + num_epochs=num_epochs, + batch_size=batch_size, + drop_final_batch=True).make_one_shot_iterator().get_next() + for tensor in nest.flatten(outputs): + if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. 
+ self.assertEqual(tensor.shape[0], batch_size) + + def testIndefiniteRepeatShapeInference(self): + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + label_key="label", + num_epochs=None, + batch_size=32) + for shape, clazz in zip(nest.flatten(dataset.output_shapes), + nest.flatten(dataset.output_classes)): + if issubclass(clazz, ops.Tensor): + self.assertEqual(32, shape[0]) + + +class MakeCsvDatasetTest(test_base.DatasetTestBase): + + def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs): + return readers.make_csv_dataset( + filenames, batch_size=batch_size, num_epochs=num_epochs, **kwargs) + + def _setup_files(self, inputs, linebreak="\n", compression_type=None): + filenames = [] + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), "temp_%d.csv" % i) + contents = linebreak.join(ip).encode("utf-8") + if compression_type is None: + with open(fn, "wb") as f: + f.write(contents) + elif compression_type == "GZIP": + with gzip.GzipFile(fn, "wb") as f: + f.write(contents) + elif compression_type == "ZLIB": + contents = zlib.compress(contents) + with open(fn, "wb") as f: + f.write(contents) + else: + raise ValueError("Unsupported compression_type", compression_type) + filenames.append(fn) + return filenames + + def _next_expected_batch(self, expected_output, expected_keys, batch_size, + num_epochs): + features = {k: [] for k in expected_keys} + for _ in range(num_epochs): + for values in expected_output: + for n, key in enumerate(expected_keys): + features[key].append(values[n]) + if len(features[expected_keys[0]]) == batch_size: + yield features + features = {k: [] for k in expected_keys} + if features[expected_keys[0]]: # Leftover from the last batch + yield features + + def _verify_output( + self, + sess, + dataset, + batch_size, + num_epochs, + label_name, + expected_output, + expected_keys, + ): + nxt = dataset.make_one_shot_iterator().get_next() + + for expected_features in self._next_expected_batch( + expected_output, + expected_keys, + batch_size, + num_epochs, + ): + actual_features = sess.run(nxt) + + if label_name is not None: + expected_labels = expected_features.pop(label_name) + self.assertAllEqual(expected_labels, actual_features[1]) + actual_features = actual_features[0] + + for k in expected_features.keys(): + # Compare features + self.assertAllEqual(expected_features[k], actual_features[k]) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + + def _test_dataset(self, + inputs, + expected_output, + expected_keys, + batch_size=1, + num_epochs=1, + label_name=None, + **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self._setup_files( + inputs, compression_type=kwargs.get("compression_type", None)) + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + dataset = self._make_csv_dataset( + filenames, + batch_size=batch_size, + num_epochs=num_epochs, + label_name=label_name, + **kwargs) + self._verify_output(sess, dataset, batch_size, num_epochs, label_name, + expected_output, expected_keys) + + def testMakeCSVDataset(self): + """Tests making a CSV dataset with keys and defaults provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in 
range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) + + def testMakeCSVDataset_withBatchSizeAndEpochs(self): + """Tests making a CSV dataset with keys and defaults provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=3, + num_epochs=10, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) + + def testMakeCSVDataset_withCompressionType(self): + """Tests `compression_type` argument.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + for compression_type in ("GZIP", "ZLIB"): + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + compression_type=compression_type, + ) + + def testMakeCSVDataset_withBadInputs(self): + """Tests that exception is raised when input is malformed. 
+ """ + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) + + # Duplicate column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + filenames, + batch_size=1, + column_defaults=record_defaults, + label_name="col0", + column_names=column_names * 2) + + # Label key not one of column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + filenames, + batch_size=1, + column_defaults=record_defaults, + label_name="not_a_real_label", + column_names=column_names) + + def testMakeCSVDataset_withNoLabel(self): + """Tests making a CSV dataset with no label provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) + + def testMakeCSVDataset_withNoHeader(self): + """Tests that datasets can be created from CSV files with no header line. + """ + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [["0,1,2,3,4", "5,6,7,8,9"], ["10,11,12,13,14", "15,16,17,18,19"]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=False, + column_defaults=record_defaults, + ) + + def testMakeCSVDataset_withTypes(self): + """Tests that defaults can be a dtype instead of a Tensor for required vals. 
+ """ + record_defaults = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, + dtypes.string + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x[0] for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], + [ + ",".join(x[0] for x in column_names), "10,11,12,13,14", + "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) + + def testMakeCSVDataset_withNoColNames(self): + """Tests that datasets can be created when column names are not specified. + + In that case, we should infer the column names from the header lines. + """ + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) + + def testMakeCSVDataset_withTypeInferenceMismatch(self): + # Test that error is thrown when num fields doesn't match columns + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) + with self.assertRaises(ValueError): + self._make_csv_dataset( + filenames, + column_names=column_names + ["extra_name"], + column_defaults=None, + batch_size=2, + num_epochs=10) + + def testMakeCSVDataset_withTypeInference(self): + """Tests that datasets can be created when no defaults are specified. + + In that case, we should infer the types from the first N records. + """ + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max + ]] + expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + ) + + def testMakeCSVDataset_withTypeInferenceFallthrough(self): + """Tests that datasets can be created when no defaults are specified. + + Tests on a deliberately tricky file. 
+ """ + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + ",,,,", + "0,0,0.0,0.0,0.0", + "0,%s,2.0,3e50,rabbit" % str_int32_max, + ",,,,", + ]] + expected_output = [[0, 0, 0, 0, b""], [0, 0, 0, 0, b"0.0"], + [0, 2**33, 2.0, 3e50, b"rabbit"], [0, 0, 0, 0, b""]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + ) + + def testMakeCSVDataset_withSelectCols(self): + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max + ]] + expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]] + + select_cols = [1, 3, 4] + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + column_defaults=[record_defaults[i] for i in select_cols], + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can still do inference without provided defaults + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can still do column name inference + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can specify column names instead of indices + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=[column_names[i] for i in select_cols], + ) + + def testMakeCSVDataset_withSelectColsError(self): + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max + ]] + + select_cols = [1, 3, 4] + filenames = self._setup_files(inputs) + + with self.assertRaises(ValueError): + # Mismatch in number of defaults and number of columns selected, + # should raise an error + self._make_csv_dataset( + filenames, + batch_size=1, + column_defaults=record_defaults, + column_names=column_names, + select_columns=select_cols) + + with self.assertRaises(ValueError): + # Invalid column name should raise an error + self._make_csv_dataset( + filenames, + batch_size=1, + column_defaults=[[0]], + column_names=column_names, + label_name=None, + select_columns=["invalid_col_name"]) + 
+ def testMakeCSVDataset_withShuffle(self): + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + def str_series(st): + return ",".join(str(i) for i in range(st, st + 5)) + + column_names = ["col%d" % i for i in range(5)] + inputs = [ + [",".join(x for x in column_names) + ] + [str_series(5 * i) for i in range(15)], + [",".join(x for x in column_names)] + + [str_series(5 * i) for i in range(15, 20)], + ] + + filenames = self._setup_files(inputs) + + total_records = 20 + for batch_size in [1, 2]: + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + # Test that shuffling with the same seed produces the same result + dataset1 = self._make_csv_dataset( + filenames, + column_defaults=record_defaults, + column_names=column_names, + batch_size=batch_size, + header=True, + shuffle=True, + shuffle_seed=5, + num_epochs=2, + ) + dataset2 = self._make_csv_dataset( + filenames, + column_defaults=record_defaults, + column_names=column_names, + batch_size=batch_size, + header=True, + shuffle=True, + shuffle_seed=5, + num_epochs=2, + ) + outputs1 = dataset1.make_one_shot_iterator().get_next() + outputs2 = dataset2.make_one_shot_iterator().get_next() + for _ in range(total_records // batch_size): + batch1 = nest.flatten(sess.run(outputs1)) + batch2 = nest.flatten(sess.run(outputs2)) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + # Test that shuffling with a different seed produces different results + dataset1 = self._make_csv_dataset( + filenames, + column_defaults=record_defaults, + column_names=column_names, + batch_size=batch_size, + header=True, + shuffle=True, + shuffle_seed=5, + num_epochs=2, + ) + dataset2 = self._make_csv_dataset( + filenames, + column_defaults=record_defaults, + column_names=column_names, + batch_size=batch_size, + header=True, + shuffle=True, + shuffle_seed=6, + num_epochs=2, + ) + outputs1 = dataset1.make_one_shot_iterator().get_next() + outputs2 = dataset2.make_one_shot_iterator().get_next() + all_equal = False + for _ in range(total_records // batch_size): + batch1 = nest.flatten(sess.run(outputs1)) + batch2 = nest.flatten(sess.run(outputs2)) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + + def testIndefiniteRepeatShapeInference(self): + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) + dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + + +class MakeTFRecordDatasetTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase): + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + 
open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length, + drop_final_batch, + use_parser_fn): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for f, r in next_records: + record = self._record(f, r) + if use_parser_fn: + record = record[1:] + record_batch.append(record) + batch_index += 1 + if len(record_batch) == batch_size: + yield record_batch + record_batch = [] + batch_index = 0 + if record_batch and not drop_final_batch: + yield record_batch + + def _verify_records(self, + sess, + outputs, + batch_size, + file_index, + num_epochs, + interleave_cycle_length, + drop_final_batch, + use_parser_fn): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length, + drop_final_batch, use_parser_fn): + actual_batch = sess.run(outputs) + self.assertAllEqual(expected_batch, actual_batch) + + def _read_test(self, batch_size, num_epochs, file_index=None, + num_parallel_reads=1, drop_final_batch=False, parser_fn=False): + if file_index is None: + file_pattern = self.test_filenames + else: + file_pattern = self.test_filenames[file_index] + + if parser_fn: + fn = lambda x: string_ops.substr(x, 1, 999) + else: + fn = None + + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + outputs = readers.make_tf_record_dataset( + file_pattern=file_pattern, + num_epochs=num_epochs, + batch_size=batch_size, + parser_fn=fn, + num_parallel_reads=num_parallel_reads, + drop_final_batch=drop_final_batch, + shuffle=False).make_one_shot_iterator().get_next() + self._verify_records( + sess, outputs, batch_size, file_index, num_epochs=num_epochs, + interleave_cycle_length=num_parallel_reads, + drop_final_batch=drop_final_batch, use_parser_fn=parser_fn) + with self.assertRaises(errors.OutOfRangeError): + sess.run(outputs) + + def testRead(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + # Basic test: read from file 0. + self._read_test(batch_size, num_epochs, 0) + + # Basic test: read from file 1. + self._read_test(batch_size, num_epochs, 1) + + # Basic test: read from both files. + self._read_test(batch_size, num_epochs) + + # Basic test: read from both files, with parallel reads. + self._read_test(batch_size, num_epochs, num_parallel_reads=8) + + def testDropFinalBatch(self): + for batch_size in [1, 2, 10]: + for num_epochs in [1, 3]: + # Read from file 0. + self._read_test(batch_size, num_epochs, 0, drop_final_batch=True) + + # Read from both files. + self._read_test(batch_size, num_epochs, drop_final_batch=True) + + # Read from both files, with parallel reads. 
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8, + drop_final_batch=True) + + def testParserFn(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for drop_final_batch in [False, True]: + self._read_test(batch_size, num_epochs, parser_fn=True, + drop_final_batch=drop_final_batch) + self._read_test(batch_size, num_epochs, num_parallel_reads=8, + parser_fn=True, drop_final_batch=drop_final_batch) + + def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1, + seed=None): + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + num_parallel_reads=num_parallel_reads, + shuffle=True, + shuffle_seed=seed) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + sess.run(iterator.initializer) + first_batches = [] + try: + while True: + first_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + sess.run(iterator.initializer) + second_batches = [] + try: + while True: + second_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + self.assertEqual(len(first_batches), len(second_batches)) + if seed is not None: + # if you set a seed, should get the same results + for i in range(len(first_batches)): + self.assertAllEqual(first_batches[i], second_batches[i]) + + expected = [] + for f in range(self._num_files): + for r in range(self._num_records): + expected.extend([self._record(f, r)] * num_epochs) + + for batches in (first_batches, second_batches): + actual = [] + for b in batches: + actual.extend(b) + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testShuffle(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for num_parallel_reads in [1, 2]: + # Test that all expected elements are produced + self._shuffle_test(batch_size, num_epochs, num_parallel_reads) + # Test that elements are produced in a consistent order if + # you specify a seed. + self._shuffle_test(batch_size, num_epochs, num_parallel_reads, + seed=21345) + + def testIndefiniteRepeatShapeInference(self): + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, num_epochs=None, batch_size=32) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py new file mode 100644 index 0000000000..b6ab80d132 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py @@ -0,0 +1,353 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Base class for testing reader datasets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.util import compat + + +class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase): + """Base class for setting up and testing FixedLengthRecordDataset.""" + + def setUp(self): + super(FixedLengthRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self._header_bytes = 5 + self._record_bytes = 3 + self._footer_bytes = 2 + + def _record(self, f, r): + return compat.as_bytes(str(f * 2 + r) * self._record_bytes) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) + filenames.append(fn) + with open(fn, "wb") as f: + f.write(b"H" * self._header_bytes) + for j in range(self._num_records): + f.write(self._record(i, j)) + f.write(b"F" * self._footer_bytes) + return filenames + + +class ReadBatchFeaturesTestBase(test_base.DatasetTestBase): + """Base class for setting up and testing `make_batched_feature_dataset`.""" + + def setUp(self): + super(ReadBatchFeaturesTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self.test_filenames = self._createFiles() + + def make_batch_feature(self, + filenames, + num_epochs, + batch_size, + label_key=None, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None, + drop_final_batch=False): + self.filenames = filenames + self.num_epochs = num_epochs + self.batch_size = batch_size + + return readers.make_batched_features_dataset( + file_pattern=self.filenames, + batch_size=self.batch_size, + features={ + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + "keywords": parsing_ops.VarLenFeature(dtypes.string), + "label": parsing_ops.FixedLenFeature([], dtypes.string), + }, + label_key=label_key, + reader=core_readers.TFRecordDataset, + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, + drop_final_batch=drop_final_batch) + + def _record(self, f, r, l): + example = example_pb2.Example( + features=feature_pb2.Features( + feature={ + "file": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[f])), + "record": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[r])), + "keywords": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=self._get_keywords(f, r))), + "label": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=[compat.as_bytes(l)])) + })) + return example.SerializeToString() + + def _get_keywords(self, f, r): + num_keywords = 1 + (f + r) % 2 + keywords = [] + for index in 
range(num_keywords): + keywords.append(compat.as_bytes("keyword%d" % index)) + return keywords + + def _sum_keywords(self, num_files): + sum_keywords = 0 + for i in range(num_files): + for j in range(self._num_records): + sum_keywords += 1 + (i + j) % 2 + return sum_keywords + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j, "fake-label")) + writer.close() + return filenames + + def _run_actual_batch(self, outputs, sess, label_key_provided=False): + if label_key_provided: + # outputs would be a tuple of (feature dict, label) + label_op = outputs[1] + features_op = outputs[0] + else: + features_op = outputs + label_op = features_op["label"] + file_op = features_op["file"] + keywords_indices_op = features_op["keywords"].indices + keywords_values_op = features_op["keywords"].values + keywords_dense_shape_op = features_op["keywords"].dense_shape + record_op = features_op["record"] + return sess.run([ + file_op, keywords_indices_op, keywords_values_op, + keywords_dense_shape_op, record_op, label_op + ]) + + def _next_actual_batch(self, sess, label_key_provided=False): + return self._run_actual_batch(self.outputs, sess, label_key_provided) + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i, compat.as_bytes("fake-label") + + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + label_batch = [] + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: + f = record[0] + r = record[1] + label_batch.append(record[2]) + file_batch.append(f) + record_batch.append(r) + keywords = self._get_keywords(f, r) + keywords_batch_values.extend(keywords) + keywords_batch_indices.extend( + [[batch_index, i] for i in range(len(keywords))]) + batch_index += 1 + keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) + if len(file_batch) == batch_size: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [batch_size, keywords_batch_max_len], record_batch, label_batch + ] + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + label_batch = [] + if file_batch: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [len(file_batch), keywords_batch_max_len], record_batch, label_batch + ] 
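`ReadBatchFeaturesTestBase.make_batch_feature` above is a thin wrapper around `readers.make_batched_features_dataset`. As a rough sketch of the corresponding public endpoint (the file pattern is hypothetical; the feature spec simply mirrors the one the test base writes via `_createFiles`):

    import tensorflow as tf

    # Hypothetical TFRecord files holding tf.Example protos with the same
    # "file"/"record"/"keywords"/"label" features used by the test base.
    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern="/tmp/tf_record.*.txt",
        batch_size=2,
        features={
            "file": tf.FixedLenFeature([], tf.int64),
            "record": tf.FixedLenFeature([], tf.int64),
            "keywords": tf.VarLenFeature(tf.string),
            "label": tf.FixedLenFeature([], tf.string),
        },
        reader=tf.data.TFRecordDataset,
        label_key="label",
        num_epochs=1,
        shuffle=False)
    features, label = dataset.make_one_shot_iterator().get_next()

Because `label_key` is provided, each batch is a `(features, label)` tuple, and `keywords` comes back as a `tf.SparseTensor`, which is why `_run_actual_batch` reads its `indices`, `values`, and `dense_shape` separately.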
+ + def verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + label_key_provided=False, + interleave_cycle_length=1): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, + batch_size, + num_epochs, + cycle_length=interleave_cycle_length): + actual_batch = self._next_actual_batch( + sess, label_key_provided=label_key_provided) + for i in range(len(expected_batch)): + self.assertAllEqual(expected_batch[i], actual_batch[i]) + + +class TextLineDatasetTestBase(test_base.DatasetTestBase): + """Base class for setting up and testing TextLineDataset.""" + + def _lineText(self, f, l): + return compat.as_bytes("%d: %d" % (f, l)) + + def _createFiles(self, + num_files, + num_lines, + crlf=False, + compression_type=None): + filenames = [] + for i in range(num_files): + fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) + filenames.append(fn) + contents = [] + for j in range(num_lines): + contents.append(self._lineText(i, j)) + # Always include a newline after the record unless it is + # at the end of the file, in which case we include it + if j + 1 != num_lines or i == 0: + contents.append(b"\r\n" if crlf else b"\n") + contents = b"".join(contents) + + if not compression_type: + with open(fn, "wb") as f: + f.write(contents) + elif compression_type == "GZIP": + with gzip.GzipFile(fn, "wb") as f: + f.write(contents) + elif compression_type == "ZLIB": + contents = zlib.compress(contents) + with open(fn, "wb") as f: + f.write(contents) + else: + raise ValueError("Unsupported compression_type", compression_type) + + return filenames + + +class TFRecordDatasetTestBase(test_base.DatasetTestBase): + """Base class for setting up and testing TFRecordDataset.""" + + def setUp(self): + super(TFRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + self.test_filenames = self._createFiles() + + self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) + self.num_epochs = array_ops.placeholder_with_default( + constant_op.constant(1, dtypes.int64), shape=[]) + self.compression_type = array_ops.placeholder_with_default("", shape=[]) + self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = core_readers.TFRecordDataset( + self.filenames, self.compression_type).repeat(self.num_epochs) + batch_dataset = repeat_dataset.batch(self.batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + self.init_op = iterator.make_initializer(repeat_dataset) + self.init_batch_op = iterator.make_initializer(batch_dataset) + self.get_next = iterator.get_next() + + def _record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames diff --git a/tensorflow/python/data/experimental/kernel_tests/resample_test.py b/tensorflow/python/data/experimental/kernel_tests/resample_test.py new file mode 100644 index 0000000000..775648c943 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/resample_test.py @@ -0,0 +1,182 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl.testing import parameterized +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.data.experimental.ops import resampling +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +def _time_resampling( + test_obj, data_np, target_dist, init_dist, num_to_sample): + dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat() + + # Reshape distribution via rejection sampling. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist, + seed=142)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with test_obj.test_session() as sess: + start_time = time.time() + for _ in xrange(num_to_sample): + sess.run(get_next) + end_time = time.time() + + return end_time - start_time + + +class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase): + + @parameterized.named_parameters( + ("InitialDistributionKnown", True), + ("InitialDistributionUnknown", False)) + def testDistribution(self, initial_known): + classes = np.random.randint(5, size=(20000,)) # Uniformly sampled + target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] + initial_dist = [0.2] * 5 if initial_known else None + classes = math_ops.to_int64(classes) # needed for Windows build. 
+ dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle( + 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat() + + get_next = dataset.apply( + resampling.rejection_resample( + target_dist=target_dist, + initial_dist=initial_dist, + class_func=lambda c, _: c, + seed=27)).make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + returned = [] + while len(returned) < 4000: + returned.append(sess.run(get_next)) + + returned_classes, returned_classes_and_data = zip(*returned) + _, returned_data = zip(*returned_classes_and_data) + self.assertAllEqual([compat.as_bytes(str(c)) + for c in returned_classes], returned_data) + total_returned = len(returned_classes) + class_counts = np.array([ + len([True for v in returned_classes if v == c]) + for c in range(5)]) + returned_dist = class_counts / total_returned + self.assertAllClose(target_dist, returned_dist, atol=1e-2) + + @parameterized.named_parameters( + ("OnlyInitial", True), + ("NotInitial", False)) + def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist): + init_dist = [0.5, 0.5] + target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test that this works. + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Reshape distribution. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + returned = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + returned.append(sess.run(get_next)) + + def testRandomClasses(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution. + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Apply a random mapping that preserves the data distribution. + def _remap_fn(_): + return math_ops.cast(random_ops.random_uniform([1]) * num_classes, + dtypes.int32)[0] + dataset = dataset.map(_remap_fn) + + # Reshape distribution. 
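+    # (rejection_resample uses rejection sampling to re-weight the input so
+    # the class frequencies approach `target_dist`; when `target_dist` equals
+    # `init_dist`, as in the only_initial_dist case, it should need to reject
+    # essentially nothing.)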
+ dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + returned = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + returned.append(sess.run(get_next)) + + classes, _ = zip(*returned) + bincount = np.bincount( + np.array(classes), + minlength=num_classes).astype(np.float32) / len(classes) + + self.assertAllClose(target_dist, bincount, atol=1e-2) + + +class ResampleDatasetBenchmark(test.Benchmark): + + def benchmarkResamplePerformance(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution + num_samples = 1000 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + resample_time = _time_resampling( + self, data_np, target_dist, init_dist, num_to_sample=1000) + + self.report_benchmark( + iters=1000, wall_time=resample_time, name="benchmark_resample") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py new file mode 100644 index 0000000000..78ec80de23 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py @@ -0,0 +1,172 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.python.data.experimental.ops import scan_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ScanDatasetTest(test_base.DatasetTestBase): + + def _counting_dataset(self, start, scan_fn): + return dataset_ops.Dataset.from_tensors(0).repeat().apply( + scan_ops.scan(start, scan_fn)) + + def testCount(self): + def make_scan_fn(step): + return lambda state, _: (state + step, state) + + start = array_ops.placeholder(dtypes.int32, shape=[]) + step = array_ops.placeholder(dtypes.int32, shape=[]) + take = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = self._counting_dataset( + start, make_scan_fn(step)).take(take).make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + + for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), + (10, 2, 10), (10, -1, 10), + (10, -2, 10)]: + sess.run(iterator.initializer, + feed_dict={start: start_val, step: step_val, take: take_val}) + for expected, _ in zip( + itertools.count(start_val, step_val), range(take_val)): + self.assertEqual(expected, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + @test_util.run_in_graph_and_eager_modes + def testFibonacci(self): + iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) + ).make_one_shot_iterator() + + if context.executing_eagerly(): + next_element = iterator.get_next + else: + get_next = iterator.get_next() + next_element = lambda: get_next + + self.assertEqual(1, self.evaluate(next_element())) + self.assertEqual(1, self.evaluate(next_element())) + self.assertEqual(2, self.evaluate(next_element())) + self.assertEqual(3, self.evaluate(next_element())) + self.assertEqual(5, self.evaluate(next_element())) + self.assertEqual(8, self.evaluate(next_element())) + + def testSparseCount(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def make_scan_fn(step): + return lambda state, _: (_sparse(state.values[0] + step), state) + + start = array_ops.placeholder(dtypes.int32, shape=[]) + step = array_ops.placeholder(dtypes.int32, shape=[]) + take = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = self._counting_dataset( + _sparse(start), + make_scan_fn(step)).take(take).make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + + for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), + (10, 2, 10), (10, -1, 10), + (10, -2, 10)]: + sess.run(iterator.initializer, + feed_dict={start: start_val, step: step_val, take: take_val}) + for expected, _ in zip( + itertools.count(start_val, 
step_val), range(take_val)): + self.assertEqual(expected, sess.run(next_element).values[0]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testChangingStateShape(self): + # Test the fixed-point shape invariant calculations: start with + # initial values with known shapes, and use a scan function that + # changes the size of the state on each element. + def _scan_fn(state, input_value): + # Statically known rank, but dynamic length. + ret_longer_vector = array_ops.concat([state[0], state[0]], 0) + # Statically unknown rank. + ret_larger_rank = array_ops.expand_dims(state[1], 0) + return (ret_longer_vector, ret_larger_rank), (state, input_value) + + dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply( + scan_ops.scan(([0], 1), _scan_fn)) + self.assertEqual([None], dataset.output_shapes[0][0].as_list()) + self.assertIs(None, dataset.output_shapes[0][1].ndims) + self.assertEqual([], dataset.output_shapes[1].as_list()) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + for i in range(5): + (longer_vector_val, larger_rank_val), _ = sess.run(next_element) + self.assertAllEqual([0] * (2**i), longer_vector_val) + self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testIncorrectStateType(self): + + def _scan_fn(state, _): + return constant_op.constant(1, dtype=dtypes.int64), state + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) + + def testIncorrectReturnType(self): + + def _scan_fn(unused_state, unused_input_value): + return constant_op.constant(1, dtype=dtypes.int64) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The scan function must return a pair comprising the new state and the " + "output value."): + dataset.apply( + scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD new file mode 100644 index 0000000000..20c02a5366 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD @@ -0,0 +1,555 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "dataset_serialization_test_base", + srcs = [ + "dataset_serialization_test_base.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/experimental/ops:iterator_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "batch_dataset_serialization_test", + size = "medium", + srcs = ["batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + 
":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "cache_dataset_serialization_test", + size = "small", + srcs = ["cache_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "concatenate_dataset_serialization_test", + size = "small", + srcs = ["concatenate_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "csv_dataset_serialization_test", + size = "small", + srcs = ["csv_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/experimental/ops:readers", + ], +) + +py_test( + name = "dataset_constructor_serialization_test", + size = "medium", + srcs = ["dataset_constructor_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "filter_dataset_serialization_test", + size = "medium", + srcs = ["filter_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fixed_length_record_dataset_serialization_test", + size = "medium", + srcs = ["fixed_length_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "flat_map_dataset_serialization_test", + size = "medium", + srcs = ["flat_map_dataset_serialization_test.py"], + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "group_by_reducer_serialization_test", + size = "medium", + srcs = ["group_by_reducer_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:grouping", + 
"//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "group_by_window_serialization_test", + size = "medium", + srcs = ["group_by_window_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:grouping", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "ignore_errors_serialization_test", + size = "small", + srcs = ["ignore_errors_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:error_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "interleave_dataset_serialization_test", + size = "medium", + srcs = ["interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "map_and_batch_dataset_serialization_test", + size = "medium", + srcs = ["map_and_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "map_dataset_serialization_test", + size = "medium", + srcs = ["map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "optimize_dataset_serialization_test", + size = "small", + srcs = ["optimize_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "padded_batch_dataset_serialization_test", + size = "medium", + srcs = ["padded_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:string_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_interleave_dataset_serialization_test", + size = "medium", + srcs = ["parallel_interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + 
"//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/experimental/ops:interleave_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_map_dataset_serialization_test", + size = "medium", + srcs = ["parallel_map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parse_example_dataset_serialization_test", + size = "medium", + srcs = ["parse_example_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", + ], +) + +py_test( + name = "prefetch_dataset_serialization_test", + size = "small", + srcs = ["prefetch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "range_dataset_serialization_test", + size = "small", + srcs = ["range_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sample_from_datasets_serialization_test", + size = "medium", + srcs = ["sample_from_datasets_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:interleave_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "scan_dataset_serialization_test", + size = "small", + srcs = ["scan_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:scan_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sequence_dataset_serialization_test", + size = "medium", + srcs = ["sequence_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "serialization_integration_test", + size = "small", + srcs = ["serialization_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/experimental/ops:iterator_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = 
"shuffle_and_repeat_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_and_repeat_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:shuffle_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/experimental/ops:iterator_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sql_dataset_serialization_test", + size = "small", + srcs = ["sql_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_op_test_base", + "//tensorflow/python/data/experimental/ops:readers", + ], +) + +py_test( + name = "stats_dataset_serialization_test", + size = "medium", + srcs = ["stats_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/experimental/ops:stats_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "textline_dataset_serialization_test", + size = "medium", + srcs = ["textline_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "tf_record_dataset_serialization_test", + size = "medium", + srcs = ["tf_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "unbatch_dataset_serialization_test", + size = "medium", + srcs = ["unbatch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "unique_dataset_serialization_test", + size = "small", + srcs = ["unique_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/experimental/ops:unique", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "zip_dataset_serialization_test", + size = "small", + srcs = ["zip_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], 
+ deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py new file mode 100644 index 0000000000..d72a6df14c --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py @@ -0,0 +1,83 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the BatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class BatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len // batch_size + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + def _build_dataset_dense_to_sparse(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + + def testDenseToSparseBatchDatasetCore(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) + + num_outputs = len(components) // 4 + self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), + lambda: self._build_dataset_dense_to_sparse(diff_comp), + num_outputs) + + def _sparse(self, i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + def _build_dataset_sparse(self, batch_size=5): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) + + def testSparseCore(self): + self.run_core_tests(self._build_dataset_sparse, + lambda: self._build_dataset_sparse(2), 2) + + def _build_dataset_nested_sparse(self): + return 
dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) + + def testNestedSparseCore(self): + self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py new file mode 100644 index 0000000000..2bcf77f5d8 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -0,0 +1,253 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the CacheDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class CacheDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): + + def setUp(self): + self.range_size = 10 + self.num_repeats = 3 + self.num_outputs = self.range_size * self.num_repeats + self.cache_file_prefix = 'test' + + def make_dataset_fn(self, is_memory): + if is_memory: + filename = '' + else: + filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix) + + def ds_fn(): + return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat( + self.num_repeats) + + return ds_fn + + def expected_outputs(self): + return list(range(self.range_size)) * self.num_repeats + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 5 entries from iterator and save checkpoint. + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 8 entries from iterator but save checkpoint after producing 5. 
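+    # (break_points=[5] makes gen_outputs save a checkpoint after the 5th
+    # element; save_checkpoint_at_end=False means no second checkpoint is
+    # written after the 8th, so the restore below resumes from element 5.)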
+    outputs = self.gen_outputs(
+        ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
+    self.assertSequenceEqual(outputs, range(8))
+
+    if is_memory:
+      outputs = outputs[:5]
+      outputs.extend(
+          self.gen_outputs(
+              ds_fn, [],
+              self.num_outputs - 5,
+              ckpt_saved=True,
+              verify_exhausted=False))
+      self.assertSequenceEqual(outputs, self.expected_outputs())
+    else:
+      # Restoring from checkpoint and running GetNext should return
+      # `AlreadyExistsError` now because the lockfile already exists.
+      with self.assertRaises(errors.AlreadyExistsError):
+        self.gen_outputs(
+            ds_fn, [],
+            self.num_outputs - 5,
+            ckpt_saved=True,
+            verify_exhausted=False)
+
+  @parameterized.named_parameters(
+      ('Memory', True),
+      ('File', False),
+  )
+  def testCheckpointAfterOneEpoch(self, is_memory):
+    ds_fn = self.make_dataset_fn(is_memory)
+
+    # Generate 15 entries from iterator and save checkpoint.
+    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
+    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
+
+    # Restore from checkpoint and produce the rest of the elements from the
+    # iterator.
+    outputs.extend(
+        self.gen_outputs(
+            ds_fn, [],
+            self.num_outputs - 15,
+            ckpt_saved=True,
+            verify_exhausted=False))
+    self.assertSequenceEqual(outputs, self.expected_outputs())
+
+  @parameterized.named_parameters(
+      ('Memory', True),
+      ('File', False),
+  )
+  def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
+    ds_fn = self.make_dataset_fn(is_memory)
+
+    # Generate 18 entries from iterator but save checkpoint after producing 15.
+    outputs = self.gen_outputs(
+        ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
+    self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))
+
+    outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
+        ds_fn, [],
+        self.num_outputs - 15,
+        ckpt_saved=True,
+        verify_exhausted=False)
+    self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+  @parameterized.named_parameters(
+      ('Memory', True),
+      ('File', False),
+  )
+  def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
+    ds_fn = self.make_dataset_fn(is_memory)
+
+    # Generate 13 entries from iterator but save checkpoint after producing 5.
+    outputs = self.gen_outputs(
+        ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
+    self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))
+
+    # Since we ran for more than one epoch, the cache was completely written.
+    # The ckpt was saved when the iterator was in cache-write mode. Test that
+    # the iterator falls back to read mode after restoring if the cache has
+    # been completely written.
+
+    outputs = list(range(5)) + self.gen_outputs(
+        ds_fn, [],
+        self.num_outputs - 5,
+        ckpt_saved=True,
+        verify_exhausted=False)
+    self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+  @parameterized.named_parameters(
+      ('Memory', True),
+      ('File', False),
+  )
+  def testCheckpointUnusedWriterIterator(self, is_memory):
+    ds_fn = self.make_dataset_fn(is_memory)
+
+    # Checkpoint before get_next is called even once.
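+    # (Requesting 0 outputs still saves a checkpoint at the end, so this
+    # snapshots an iterator before it has produced anything.)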
+ outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False) + self.assertSequenceEqual(outputs, []) + + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedMidwayWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Produce 5 elements and checkpoint. + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint, then produce no elements and checkpoint. + outputs.extend( + self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce rest of the elements. + outputs.extend( + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testUnusedCheckpointError(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Produce 5 elements and save ckpt. + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + if is_memory: + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testIgnoreCheckpointIfCacheWritten(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Produce 15 elements and save ckpt. This will write the complete cache. + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Build the iterator again but do not restore from ckpt. Since the cache + # has already been written we should be able to use it. + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py new file mode 100644 index 0000000000..c075dff8cb --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py @@ -0,0 +1,49 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ConcatenateDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ConcatenateDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_concatenate_dataset(self, var_array): + input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 4)) + to_concatenate_components = (np.tile( + np.array([[5], [6], [7], [8], [9]]), 20), var_array) + + return dataset_ops.Dataset.from_tensor_slices(input_components).concatenate( + dataset_ops.Dataset.from_tensor_slices(to_concatenate_components)) + + def testConcatenateCore(self): + num_outputs = 9 + array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15) + diff_array = np.array([[1], [2], [3], [4], [5]]) + self.run_core_tests(lambda: self._build_concatenate_dataset(array), + lambda: self._build_concatenate_dataset(diff_array), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py new file mode 100644 index 0000000000..d4983492e7 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py @@ -0,0 +1,73 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the CsvDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.platform import test + + +class CsvDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._num_cols = 7 + self._num_rows = 10 + self._num_epochs = 14 + self._num_outputs = self._num_rows * self._num_epochs + + inputs = [ + ",".join(str(self._num_cols * j + i) + for i in range(self._num_cols)) + for j in range(self._num_rows) + ] + contents = "\n".join(inputs).encode("utf-8") + + self._filename = os.path.join(self.get_temp_dir(), "file.csv") + self._compressed = os.path.join(self.get_temp_dir(), + "comp.csv") # GZip compressed + + with open(self._filename, "wb") as f: + f.write(contents) + with gzip.GzipFile(self._compressed, "wb") as f: + f.write(contents) + + def ds_func(self, **kwargs): + compression_type = kwargs.get("compression_type", None) + if compression_type == "GZIP": + filename = self._compressed + elif compression_type is None: + filename = self._filename + else: + raise ValueError("Invalid compression type:", compression_type) + + return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs) + + def testSerializationCore(self): + defs = [[0]] * self._num_cols + self.run_core_tests( + lambda: self.ds_func(record_defaults=defs, buffer_size=2), + lambda: self.ds_func(record_defaults=defs, buffer_size=12), + self._num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py new file mode 100644 index 0000000000..41a095fb1a --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the dataset constructors serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test + + +class FromTensorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_dataset(self, variable_array): + components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) + + return dataset_ops.Dataset.from_tensors(components) + + def testFromTensorsCore(self): + # Equal length components + arr = np.array(1) + num_outputs = 1 + diff_arr = np.array(2) + self.run_core_tests(lambda: self._build_tensor_dataset(arr), + lambda: self._build_tensor_dataset(diff_arr), + num_outputs) + + +class FromTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_slices_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components) + + def testFromTensorSlicesCore(self): + # Equal length components + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array([37.0, 38.0, 39.0, 40.0])) + + diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[5], [6], [7], [8]]), 22), + np.array([1.0, 2.0, 3.0, 4.0])) + + dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), + lambda: self._build_tensor_slices_dataset(diff_comp), 4) + self.run_core_tests( + lambda: self._build_tensor_slices_dataset(dict_components), None, 3) + + +class FromSparseTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_sparse_tensor_slice_dataset(self, slices): + indices = np.array( + [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], + dtype=np.int64) + values = np.array([val for s in slices for val in s], dtype=np.float64) + dense_shape = np.array( + [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) + sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) + return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) + + def testFromSparseTensorSlicesCore(self): + slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] + diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] + + self.run_core_tests( + lambda: self._build_sparse_tensor_slice_dataset(slices), + lambda: self._build_sparse_tensor_slice_dataset(diff_slices), + 9, + sparse_tensors=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py new file mode 100644 index 0000000000..7f435b8239 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py @@ -0,0 +1,692 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing serializable datasets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import nest + + +def remove_variants(get_next_op): + # TODO(b/72408568): Remove this once session.run can get + # variant tensors. + """Remove variants from a nest structure, so sess.run will execute.""" + + def _remove_variant(x): + if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: + return () + else: + return x + + return nest.map_structure(_remove_variant, get_next_op) + + +class DatasetSerializationTestBase(test.TestCase): + """Base class for testing serializable datasets.""" + + def tearDown(self): + self._delete_ckpt() + + # TODO(b/72657739): Remove sparse_tensor argument, which is to test the + # (deprecated) saveable `SparseTensorSliceDataset`, once the API + # `from_sparse_tensor_slices()`and related tests are deleted. + def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): + """Runs the core tests. + + Args: + ds_fn1: 0-argument function that returns a Dataset. + ds_fn2: 0-argument function that returns a Dataset different from + ds_fn1. If None, verify_restore_in_modified_graph test is not run. + num_outputs: Total number of outputs expected from this Dataset. + sparse_tensors: Whether dataset is built from SparseTensor(s). + + Raises: + AssertionError if any test fails. 
+ """ + self.verify_unused_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_fully_used_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_exhausted_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_init_before_restore( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_multiple_breaks( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_reset_restored_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_restore_in_empty_graph( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + if ds_fn2: + self.verify_restore_in_modified_graph( + ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors) + + def verify_unused_iterator(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that saving and restoring an unused iterator works. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, [0], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_fully_used_iterator(self, ds_fn, num_outputs, + sparse_tensors=False): + """Verifies that saving and restoring a fully used iterator works. + + Note that this only checks saving and restoring an iterator from which + `num_outputs` items have been produced but does not check for an + exhausted iterator, i.e., one from which an OutOfRange error has been + returned. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if test fails. + """ + self.verify_run_with_breaks( + ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors) + + def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False): + """Verifies that saving and restoring an exhausted iterator works. + + An exhausted iterator is one which has returned an OutOfRange error. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + self.gen_outputs( + ds_fn, [], + num_outputs, + verify_exhausted=True, + sparse_tensors=sparse_tensors) + actual = self.gen_outputs( + ds_fn, [], + 0, + ckpt_saved=True, + verify_exhausted=True, + sparse_tensors=sparse_tensors) + self.assertEqual(len(actual), 0) + + def verify_init_before_restore(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that restoring into an already initialized iterator works. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs), + num_outputs, + init_before_restore=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_multiple_breaks(self, + ds_fn, + num_outputs, + num_breaks=10, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to save/restore at multiple break points. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + num_breaks: The number of break points. These are uniformly spread in + [0, num_outputs] both inclusive. 
+ sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs, num_breaks), + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_reset_restored_iterator(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to re-initialize a restored iterator. + + This is useful when restoring a training checkpoint during validation. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Collect ground truth containing all outputs. + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Skip some items and save checkpoint. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Restore from checkpoint and then run init_op. + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + self._restore(saver, sess) + self._initialize(init_op, sess) + for _ in range(num_outputs): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self.match(expected, actual) + + def verify_restore_in_modified_graph(self, + ds_fn1, + ds_fn2, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to restore an iterator in a modified graph. + + Builds an input pipeline using ds_fn1, runs it for `break_point` steps + and saves a checkpoint. Then builds a new graph using ds_fn2, restores + the checkpoint from ds_fn1 and verifies that the restore is successful. + + Args: + ds_fn1: See `run_core_tests`. + ds_fn2: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn1 + # in `expected`. + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + expected = self.gen_outputs( + ds_fn1, [], + num_outputs - break_point, + ckpt_saved=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Generate `break_point` items from ds_fn1 and save checkpoint. + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Build graph for ds_fn2 but load checkpoint for ds_fn1. 
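+    # (The checkpoint written above from the ds_fn1 pipeline must restore
+    # cleanly into the iterator built from ds_fn2 below; only the remaining
+    # num_outputs - break_point elements are then compared.)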
+ with ops.Graph().as_default() as g: + _, get_next_op, saver = self._build_graph( + ds_fn2, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.match(expected, actual) + + def verify_restore_in_empty_graph(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to restore an iterator in an empty graph. + + Builds an input pipeline using ds_fn, runs it for `break_point` steps + and saves a checkpoint. Then builds a new empty graph, restores + the checkpoint from ds_fn and verifies that the restore is successful. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn + # in `expected`. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + expected = self.gen_outputs( + ds_fn, [], + num_outputs - break_point, + ckpt_saved=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Generate `break_point` items from ds_fn and save checkpoint. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Build an empty graph but load checkpoint for ds_fn. + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph( + ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.match(expected, actual) + + def verify_error_on_save(self, + ds_fn, + num_outputs, + error, + break_point=None, + sparse_tensors=False): + """Attempts to save a non-saveable iterator. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + error: Declared error when trying to save iterator. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + + break_point = num_outputs // 2 if not break_point else break_point + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + self._initialize(init_op, sess) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(error): + self._save(sess, saver) + + def verify_run_with_breaks(self, + ds_fn, + break_points, + num_outputs, + init_before_restore=False, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that ds_fn() produces the same outputs with and without breaks. + + 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it + *without* stopping at break points. + 2. 
Builds a Dataset using `ds_fn` and produces `num_outputs` items from it + while stopping at break points. + + Deep-matches the outputs from steps 1 and 2. + + Args: + ds_fn: See `gen_outputs`. + break_points: See `gen_outputs`. + num_outputs: See `gen_outputs`. + init_before_restore: See `gen_outputs`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + actual = self.gen_outputs( + ds_fn, + break_points, + num_outputs, + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + self.match(expected, actual) + + def gen_outputs(self, + ds_fn, + break_points, + num_outputs, + ckpt_saved=False, + init_before_restore=False, + sparse_tensors=False, + verify_exhausted=True, + save_checkpoint_at_end=True): + """Generates elements from input dataset while stopping at break points. + + Produces `num_outputs` outputs and saves the state of the iterator in the + Saver checkpoint. + + Args: + ds_fn: 0-argument function that returns the dataset. + break_points: A list of integers. For each `break_point` in + `break_points`, we produce outputs until `break_point` items have been + produced and then checkpoint the state. The current graph + and session are destroyed and a new graph and session are used to + produce outputs until the next checkpoint or until `num_outputs` elements + have been produced. `break_point` must be <= `num_outputs`. + num_outputs: The total number of outputs to produce from the iterator. + ckpt_saved: Whether a checkpoint already exists. If False, we build the + graph from ds_fn. + init_before_restore: Whether init should be called before saver.restore. + This is just so that we can verify that restoring an already initialized + iterator works. + sparse_tensors: Whether dataset is built from SparseTensor(s). + verify_exhausted: Whether to verify that the iterator has been exhausted + after producing `num_outputs` elements. + save_checkpoint_at_end: Whether to save a checkpoint after producing all + outputs. If False, checkpoints are saved at each break point but not at + the end. Note that checkpoints overwrite each other so there is always only + a single checkpoint available. Defaults to True. + + Returns: + A list of `num_outputs` items.
+ """ + outputs = [] + + def get_ops(): + if ckpt_saved: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) + else: + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) + return init_op, get_next_op, saver + + for i in range(len(break_points) + 1): + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = get_ops() + get_next_op = remove_variants(get_next_op) + with self.session(graph=g) as sess: + if ckpt_saved: + if init_before_restore: + self._initialize(init_op, sess) + self._restore(saver, sess) + else: + self._initialize(init_op, sess) + start = break_points[i - 1] if i > 0 else 0 + end = break_points[i] if i < len(break_points) else num_outputs + num_iters = end - start + for _ in range(num_iters): + outputs.append(sess.run(get_next_op)) + if i == len(break_points) and verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + if save_checkpoint_at_end or i < len(break_points): + self._save(sess, saver) + ckpt_saved = True + + return outputs + + def match(self, expected, actual): + """Matches nested structures. + + Recursively matches shape and values of `expected` and `actual`. + Handles scalars, numpy arrays and other python sequence containers + e.g. list, dict. + + Args: + expected: Nested structure 1. + actual: Nested structure 2. + + Raises: + AssertionError if matching fails. + """ + if isinstance(expected, np.ndarray): + expected = expected.tolist() + if isinstance(actual, np.ndarray): + actual = actual.tolist() + self.assertEqual(type(expected), type(actual)) + + if nest.is_sequence(expected): + self.assertEqual(len(expected), len(actual)) + if isinstance(expected, dict): + for key1, key2 in zip(sorted(expected), sorted(actual)): + self.assertEqual(key1, key2) + self.match(expected[key1], actual[key2]) + else: + for item1, item2 in zip(expected, actual): + self.match(item1, item2) + else: + self.assertEqual(expected, actual) + + def does_not_match(self, expected, actual): + with self.assertRaises(AssertionError): + self.match(expected, actual) + + def gen_break_points(self, num_outputs, num_samples=10): + """Generates `num_samples` breaks points in [0, num_outputs].""" + return np.linspace(0, num_outputs, num_samples, dtype=int) + + def _build_graph(self, ds_fn, sparse_tensors=False): + iterator = ds_fn().make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + self._add_iterator_ops_to_collection(init_op, get_next, ds_fn, + sparse_tensors) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _build_empty_graph(self, ds_fn, sparse_tensors=False): + iterator = iterator_ops.Iterator.from_structure( + self._get_output_types(ds_fn), + output_shapes=self._get_output_shapes(ds_fn), + output_classes=self._get_output_classes(ds_fn)) + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return get_next, saver + + def _add_iterator_ops_to_collection(self, + 
init_op, + get_next, + ds_fn, + sparse_tensors=False): + ops.add_to_collection("iterator_ops", init_op) + # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections + # do not support tuples we flatten the tensors and restore the shape in + # `_get_iterator_ops_from_collection`. + if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. + ops.add_to_collection("iterator_ops", get_next.indices) + ops.add_to_collection("iterator_ops", get_next.values) + ops.add_to_collection("iterator_ops", get_next.dense_shape) + return + + get_next_list = nest.flatten(get_next) + for i, output_class in enumerate( + nest.flatten(self._get_output_classes(ds_fn))): + if output_class is sparse_tensor.SparseTensor: + ops.add_to_collection("iterator_ops", get_next_list[i].indices) + ops.add_to_collection("iterator_ops", get_next_list[i].values) + ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape) + else: + ops.add_to_collection("iterator_ops", get_next_list[i]) + + def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): + all_ops = ops.get_collection("iterator_ops") + if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. + init_op, indices, values, dense_shape = all_ops + return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) + get_next_list = [] + i = 1 + for output_class in nest.flatten(self._get_output_classes(ds_fn)): + if output_class is sparse_tensor.SparseTensor: + indices, values, dense_shape = all_ops[i:i + 3] + i += 3 + get_next_list.append( + sparse_tensor.SparseTensor(indices, values, dense_shape)) + else: + get_next_list.append(all_ops[i]) + i += 1 + return all_ops[0], nest.pack_sequence_as( + self._get_output_types(ds_fn), get_next_list) + + def _get_output_types(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_types + + def _get_output_shapes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_shapes + + def _get_output_classes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_classes + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return checkpoint_management.latest_checkpoint(self.get_temp_dir()) + + def _save(self, sess, saver): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + sess.run(lookup_ops.tables_initializer()) + saver.restore(sess, self._latest_ckpt()) + + def _initialize(self, init_op, sess): + sess.run(variables.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + sess.run(init_op) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _delete_ckpt(self): + # Remove all checkpoint files. + prefix = self._ckpt_path() + pattern = prefix + "*" + files = gfile.Glob(pattern) + for f in files: gfile.Remove(f) diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py new file mode 100644 index 0000000000..225f6cbac0 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py @@ -0,0 +1,71 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the FilterDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class FilterDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_filter_range_graph(self, div): + return dataset_ops.Dataset.range(100).filter( + lambda x: math_ops.not_equal(math_ops.mod(x, div), 2)) + + def testFilterCore(self): + div = 3 + num_outputs = np.sum([x % 3 != 2 for x in range(100)]) + self.run_core_tests(lambda: self._build_filter_range_graph(div), + lambda: self._build_filter_range_graph(div * 2), + num_outputs) + + def _build_filter_dict_graph(self): + return dataset_ops.Dataset.range(10).map( + lambda x: {"foo": x * 2, "bar": x ** 2}).filter( + lambda d: math_ops.equal(d["bar"] % 2, 0)).map( + lambda d: d["foo"] + d["bar"]) + + def testFilterDictCore(self): + num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)]) + self.run_core_tests(self._build_filter_dict_graph, None, num_outputs) + + def _build_sparse_filter(self): + + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + def _filter_fn(_, i): + return math_ops.equal(i % 2, 0) + + return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( + lambda x, i: x) + + def testSparseCore(self): + num_outputs = 5 + self.run_core_tests(self._build_sparse_filter, None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py new file mode 100644 index 0000000000..70caf3e0d5 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the FixedLengthRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class FixedLengthRecordDatasetSerializationTest( + reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, num_epochs, compression_type=None): + filenames = self._createFiles() + return core_readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, + self._footer_bytes).repeat(num_epochs) + + def testFixedLengthRecordCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py new file mode 100644 index 0000000000..c30534a9e9 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py @@ -0,0 +1,122 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the FlatMapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class FlatMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + # Complicated way of saying range(start, start+25). 
+ def build_ds(start): + + def map_fn(x): + return dataset_ops.Dataset.range(x, x + 5) + + return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn) + + self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25) + + def testMapThenFlatMap(self): + + def build_ds(): + + def flat_map_fn(_): + + def map_fn(y): + return 10 * math_ops.to_int32(y) + + return dataset_ops.Dataset.range(100).map(map_fn) + + return dataset_ops.Dataset.range(5).flat_map(flat_map_fn) + + self.run_core_tests(build_ds, None, 500) + + def testCaptureDefunInMapFn(self): + + def build_ds(): + + def map_fn(x): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)]) + + return dataset_ops.Dataset.range(100).flat_map(map_fn) + + self.run_core_tests(build_ds, None, 100) + + def testDisallowVariableCapture(self): + + def build_ds(): + test_var = variable_scope.get_variable( + name="test_var", shape=(), use_resource=True) + return dataset_ops.Dataset.range(5).flat_map( + lambda _: dataset_ops.Dataset.from_tensor_slices([test_var])) + + self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError) + + def testDisallowCapturingStatefulOps(self): + + def build_ds(): + + def flat_map_fn(_): + + def map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(map_fn) + + return dataset_ops.Dataset.range(5).flat_map(flat_map_fn) + + self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _flat_map_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_ds(): + return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + + self.run_core_tests(_build_ds, None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py new file mode 100644 index 0000000000..169c8845d0 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py @@ -0,0 +1,61 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the GroupByReducer serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py new file mode 100644 index 0000000000..e5bc76288e --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the GroupByWindow serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByWindowSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) + + def testCoreGroupByWindow(self): + components = np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 12, + verify_exhausted=False) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py new file mode 100644 index 0000000000..df1f43129a --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the IgnoreErrors input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import error_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class IgnoreErrorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors()) + + def testIgnoreErrorsCore(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) + num_outputs = 4 + self.run_core_tests(lambda: self._build_ds(components), + lambda: self._build_ds(diff_components), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py new file mode 100644 index 0000000000..0c1d40ce39 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -0,0 +1,83 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the InterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class InterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): + + def _build_iterator_graph(self, input_values, cycle_length, block_length, + num_parallel_calls): + repeat_count = 2 + return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + repeat_count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length, num_parallel_calls) + + @parameterized.named_parameters( + ("1", 2, 3, None), + ("2", 2, 3, 1), + ("3", 2, 3, 2), + ("4", 1, 3, None), + ("5", 1, 3, 1), + ("6", 2, 1, None), + ("7", 2, 1, 1), + ("8", 2, 1, 2), + ) + def testSerializationCore(self, cycle_length, block_length, + num_parallel_calls): + input_values = np.array([4, 5, 6], dtype=np.int64) + num_outputs = np.sum(input_values) * 2 + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length, num_parallel_calls), + lambda: self._build_iterator_graph( + input_values, cycle_length * 2, block_length, num_parallel_calls), + num_outputs) + # pylint: enable=g-long-lambda + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py new file mode 100644 index 0000000000..166ffa99ca --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the MapAndBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapAndBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testNumParallelBatches(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_batches = 2 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py new file mode 100644 index 0000000000..b93156a96c --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py @@ -0,0 +1,140 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the MapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testSparseCore(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _build_ds(num_outputs): + return 
dataset_ops.Dataset.range(num_outputs).map(_sparse) + + num_outputs = 10 + self.run_core_tests(lambda: _build_ds(num_outputs), + lambda: _build_ds(int(num_outputs / 2)), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py new file mode 100644 index 0000000000..ed4a1da596 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the OptimizeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py new file mode 100644 index 0000000000..6f72b24673 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the PaddedBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class PaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=[-1]) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "<end>")) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py new file mode 100644 index 0000000000..b8f38e8a28 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the ParallelInterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class ParallelInterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.input_values = np.array([4, 5, 6], dtype=np.int64) + self.num_repeats = 2 + self.num_outputs = np.sum(self.input_values) * 2 + + def _build_ds(self, cycle_length, block_length, sloppy=False): + return (dataset_ops.Dataset.from_tensor_slices( + self.input_values).repeat(self.num_repeats).apply( + interleave_ops.parallel_interleave( + lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), + cycle_length, block_length, sloppy))) + + def testSerializationCore(self): + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + self.run_core_tests( + lambda: self._build_ds(cycle_length, block_length), + lambda: self._build_ds(cycle_length * 2, block_length * 1), + self.num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + + def testSerializationWithSloppy(self): + break_points = self.gen_break_points(self.num_outputs, 10) + expected_outputs = np.repeat( + np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), + self.num_repeats).tolist() + + def run_test(cycle_length, block_length): + actual = self.gen_outputs( + lambda: self._build_ds(cycle_length, block_length, True), + break_points, self.num_outputs) + self.assertSequenceEqual(sorted(actual), expected_outputs) + + # cycle_length > 1, block_length > 1 + run_test(2, 3) + # cycle_length = 1 + run_test(1, 3) + # block_length = 1 + run_test(2, 1) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).apply( + interleave_ops.parallel_interleave(_interleave_fn, 1)) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py new file mode 100644 index 0000000000..a0bdd4fa59 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelMapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class ParallelMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 1 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) + + def _build_ds_with_prefetch(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) + + def testSaveRestoreCore(self): + for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: + self.run_core_tests( + ds_fn, + lambda: ds_fn(multiplier=15.0), # pylint: disable=cell-var-from-loop + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map( + _map_fn, num_parallel_calls=2).prefetch(2) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1), + num_parallel_calls=2).prefetch(2)) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + 
constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py new file mode 100644 index 0000000000..a0dd6960b0 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the ParseExampleDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.platform import test + + +class ParseExampleDatasetSerializationTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def ParseExampleDataset(self, num_repeat, batch_size): + return self.make_batch_feature( + filenames=self.test_filenames, + num_epochs=num_repeat, + batch_size=batch_size, + reader_num_threads=5, + parser_num_threads=10) + + def testSerializationCore(self): + num_repeat = 5 + batch_size = 2 + num_outputs = self._num_records * self._num_files * num_repeat // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self.ParseExampleDataset( + num_repeat=num_repeat, batch_size=batch_size), + lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py new file mode 100644 index 0000000000..00d74c0025 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the PrefetchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class PrefetchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, seed): + return dataset_ops.Dataset.range(100).prefetch(10).shuffle( + buffer_size=10, seed=seed, reshuffle_each_iteration=False) + + def testCore(self): + num_outputs = 100 + self.run_core_tests(lambda: self.build_dataset(10), + lambda: self.build_dataset(20), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py new file mode 100644 index 0000000000..ef99d01c73 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the RangeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RangeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _iterator_checkpoint_prefix_local(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix_local(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix_local()), + dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def testSaveRestore(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Saving and restoring in same session. 
+ with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def _build_range_dataset(self, start, stop): + return dataset_ops.Dataset.range(start, stop) + + def testRangeCore(self): + start = 2 + stop = 10 + stop_1 = 8 + self.run_core_tests(lambda: self._build_range_dataset(start, stop), + lambda: self._build_range_dataset(start, stop_1), + stop - start) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py new file mode 100644 index 0000000000..c23c1ecdfb --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SampleFromDatasets serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py new file mode 100644 index 0000000000..5f50160619 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ScanDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ScanDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py new file mode 100644 index 0000000000..fe99a3d3d9 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py @@ -0,0 +1,129 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the sequence datasets serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class SkipDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_skip_dataset(self, count): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).skip(count) + + def testSkipFewerThanInputs(self): + count = 4 + num_outputs = 10 - count + self.run_core_tests(lambda: self._build_skip_dataset(count), + lambda: self._build_skip_dataset(count + 2), + num_outputs) + + def testSkipVarious(self): + # Skip more than inputs + self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0) + # Skip exactly the input size + self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0) + self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0) + # Skip nothing + self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10) + + def testInvalidSkip(self): + with self.assertRaisesRegexp(ValueError, + 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) + + +class TakeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_take_dataset(self, count): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).take(count) + + def testTakeFewerThanInputs(self): + count = 4 + self.run_core_tests( + lambda: self._build_take_dataset(count), + lambda: self._build_take_dataset(count + 2), + count, + ) + + def testTakeVarious(self): + # Take more than inputs + self.run_core_tests(lambda: self._build_take_dataset(20), None, 10) + # Take exactly the input size + self.run_core_tests(lambda: self._build_take_dataset(10), None, 10) + # Take all + self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10) + # Take nothing + self.run_core_tests(lambda: self._build_take_dataset(0), None, 0) + + def testInvalidTake(self): + with self.assertRaisesRegexp(ValueError, + 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) + + +class RepeatDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_repeat_dataset(self, count, take_count=3): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).take( + take_count).repeat(count) + + def testFiniteRepeat(self): + count = 10 + self.run_core_tests(lambda: self._build_repeat_dataset(count), + lambda: self._build_repeat_dataset(count + 2), + 3 * count) + + def testEmptyRepeat(self): + self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0) + + def testInfiniteRepeat(self): + self.verify_unused_iterator( + lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False) + 
self.verify_restore_in_modified_graph( + lambda: self._build_repeat_dataset(-1), + lambda: self._build_repeat_dataset(2), + 20, + verify_exhausted=False) + # Test repeat empty dataset + self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0) + + def testInvalidRepeat(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0), + None, 0) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py new file mode 100644 index 0000000000..88d5c896c9 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py @@ -0,0 +1,85 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration test for dataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class SerializationIntegrationTest(test.TestCase): + + def _build_input_pipeline(self, name, num_outputs): + with ops.name_scope(name): + ds = dataset_ops.Dataset.range(num_outputs).shuffle( + 10, reshuffle_each_iteration=False).prefetch(10) + iterator = ds.make_initializable_iterator() + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + return iterator.initializer, iterator.get_next() + + def _build_graph(self, num_pipelines, num_outputs): + init_ops = [] + get_next_ops = [] + for i in range(num_pipelines): + name = "input_pipeline_%d" % i + init_op, get_next_op = self._build_input_pipeline(name, num_outputs) + init_ops.append(init_op) + get_next_ops.append(get_next_op) + saver = saver_lib.Saver() + return init_ops, get_next_ops, saver + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def testConcurrentSaves(self): + num_pipelines = 100 + num_outputs = 100 + break_point = 10 + all_outputs = [[] for _ in range(num_pipelines)] + with ops.Graph().as_default() as g: + init_ops, get_next_ops, saver = self._build_graph(num_pipelines, + num_outputs) + with self.session(graph=g) as sess: + sess.run(init_ops) + for _ in range(break_point): + output = sess.run(get_next_ops) + for i in range(num_pipelines): + all_outputs[i].append(output[i]) + saver.save(sess, self._ckpt_path()) + + with ops.Graph().as_default() as g: + init_ops, get_next_ops, saver = 
self._build_graph(num_pipelines, + num_outputs) + with self.session(graph=g) as sess: + saver.restore(sess, self._ckpt_path()) + for _ in range(num_outputs - break_point): + output = sess.run(get_next_ops) + for i in range(num_pipelines): + all_outputs[i].append(output[i]) + + for output in all_outputs: + self.assertSequenceEqual(sorted(output), range(num_outputs)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py new file mode 100644 index 0000000000..f847ac19f9 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleAndRepeatDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import shuffle_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py new file mode 100644 index 0000000000..a04f1ddafc --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py @@ -0,0 +1,148 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the ShuffleDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class ShuffleDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_shuffle_dataset( + self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + ): + return dataset_ops.Dataset.range(range_limit).shuffle( + buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) + + def testShuffleCore(self): + + seed = 55 + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + # pylint: disable=cell-var-from-loop + # pylint: disable=g-long-lambda + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. + expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. 
+ reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py new file mode 100644 index 0000000000..b179770ce3 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the SqlDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetSerializationTest( + sql_dataset_op_test_base.SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py new file mode 100644 index 0000000000..ef7061b190 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -0,0 +1,106 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the StatsDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import stats_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the +# transformation `stats_ops.set_stats_aggregator`, since we don't support +# serializing StatsAggregator yet. 
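For orientation, a minimal sketch (not part of this change) of the pipeline shape that TODO refers to, built only from the `stats_ops` helpers exercised in this file:

    # Illustrative only; `stats_ops` and `dataset_ops` are the modules imported above.
    aggregator = stats_ops.StatsAggregator()
    ds = dataset_ops.Dataset.range(10).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.set_stats_aggregator(aggregator))
    # Checkpointing an iterator over `ds` is expected to raise
    # errors.UnimplementedError ("... does not support checkpointing"),
    # which is exactly what test_set_stats_aggregator_not_support_checkpointing
    # in the class below asserts.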
+class StatsDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset_bytes_stats(self, num_elements): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + + def test_bytes_produced_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.bytes_produced_stats(["bytes_produced"])), + None, 100) + # pylint: enable=g-long-lambda + + def testBytesStatsDatasetSaveableCore(self): + num_outputs = 100 + self.run_core_tests( + lambda: self._build_dataset_bytes_stats(num_outputs), + lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def test_latency_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats(["record_latency", "record_latency_2"])), + None, 100) + # pylint: enable=g-long-lambda + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + + def _build_dataset_stats_aggregator(self): + stats_aggregator = stats_ops.StatsAggregator() + return dataset_ops.Dataset.range(10).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + + def test_set_stats_aggregator_not_support_checkpointing(self): + with self.assertRaisesRegexp(errors.UnimplementedError, + "does not support checkpointing"): + self.run_core_tests(self._build_dataset_stats_aggregator, None, 10) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py new file mode 100644 index 0000000000..c87a7443a7 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TextLineDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TextLineDatasetSerializationTest( + reader_dataset_ops_test_base.TextLineDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, test_filenames, compression_type=None): + return core_readers.TextLineDataset( + test_filenames, compression_type=compression_type, buffer_size=10) + + def testTextLineCore(self): + compression_types = [None, "GZIP", "ZLIB"] + num_files = 5 + lines_per_file = 5 + num_outputs = num_files * lines_per_file + for compression_type in compression_types: + test_filenames = self._createFiles( + num_files, + lines_per_file, + crlf=True, + compression_type=compression_type) + # pylint: disable=cell-var-from-loop + self.run_core_tests( + lambda: self._build_iterator_graph(test_filenames, compression_type), + lambda: self._build_iterator_graph(test_filenames), num_outputs) + # pylint: enable=cell-var-from-loop + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py new file mode 100644 index 0000000000..f0dcc131d4 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the TFRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TFRecordDatasetSerializationTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, + num_epochs, + batch_size=1, + compression_type=None, + buffer_size=None): + filenames = self._createFiles() + if compression_type == "ZLIB": + zlib_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + filenames = zlib_files + + elif compression_type == "GZIP": + gzip_files = [] + for i, fn in enumerate(self.test_filenames): + with open(fn, "rb") as f: + gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(gzfn, "wb") as gzf: + gzf.write(f.read()) + gzip_files.append(gzfn) + filenames = gzip_files + + return core_readers.TFRecordDataset( + filenames, compression_type, + buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) + + def testTFRecordWithoutBufferCore(self): + num_epochs = 5 + batch_size = num_epochs + num_outputs = num_epochs * self._num_files * self._num_records // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, batch_size, + buffer_size=0), + lambda: self._build_iterator_graph(num_epochs * 2, batch_size), + num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, + num_outputs * batch_size) + # pylint: enable=g-long-lambda + + def testTFRecordWithBufferCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + def testTFRecordWithCompressionCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py new file mode 100644 index 0000000000..528598dfe4 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UnbatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UnbatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch( + batch_size).apply(batching.unbatch()) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py new file mode 100644 index 0000000000..e2862af4d6 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the UniqueDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UniqueDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testUnique(self): + + def build_dataset(num_elements, unique_elem_range): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: x % unique_elem_range).apply(unique.unique()) + + self.run_core_tests(lambda: build_dataset(200, 100), + lambda: build_dataset(40, 100), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py new file mode 100644 index 0000000000..4ea6131c22 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the ZipDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ZipDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, arr): + components = [ + np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array(arr) + ] + datasets = [ + dataset_ops.Dataset.from_tensor_slices(component) + for component in components + ] + return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2]))) + + def testCore(self): + # Equal length components + arr = [37.0, 38.0, 39.0, 40.0] + num_outputs = len(arr) + self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs) + # Variable length components + diff_size_arr = [1.0, 2.0] + self.run_core_tests(lambda: self._build_dataset(diff_size_arr), + lambda: self._build_dataset(arr), 2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py new file mode 100644 index 0000000000..88d5c896c9 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py @@ -0,0 +1,85 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Integration test for dataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class SerializationIntegrationTest(test.TestCase): + + def _build_input_pipeline(self, name, num_outputs): + with ops.name_scope(name): + ds = dataset_ops.Dataset.range(num_outputs).shuffle( + 10, reshuffle_each_iteration=False).prefetch(10) + iterator = ds.make_initializable_iterator() + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + return iterator.initializer, iterator.get_next() + + def _build_graph(self, num_pipelines, num_outputs): + init_ops = [] + get_next_ops = [] + for i in range(num_pipelines): + name = "input_pipeline_%d" % i + init_op, get_next_op = self._build_input_pipeline(name, num_outputs) + init_ops.append(init_op) + get_next_ops.append(get_next_op) + saver = saver_lib.Saver() + return init_ops, get_next_ops, saver + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def testConcurrentSaves(self): + num_pipelines = 100 + num_outputs = 100 + break_point = 10 + all_outputs = [[] for _ in range(num_pipelines)] + with ops.Graph().as_default() as g: + init_ops, get_next_ops, saver = self._build_graph(num_pipelines, + num_outputs) + with self.session(graph=g) as sess: + sess.run(init_ops) + for _ in range(break_point): + output = sess.run(get_next_ops) + for i in range(num_pipelines): + all_outputs[i].append(output[i]) + saver.save(sess, self._ckpt_path()) + + with ops.Graph().as_default() as g: + init_ops, get_next_ops, saver = self._build_graph(num_pipelines, + num_outputs) + with self.session(graph=g) as sess: + saver.restore(sess, self._ckpt_path()) + for _ in range(num_outputs - break_point): + output = sess.run(get_next_ops) + for i in range(num_pipelines): + all_outputs[i].append(output[i]) + + for output in all_outputs: + self.assertSequenceEqual(sorted(output), range(num_outputs)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py new file mode 100644 index 0000000000..50895b5945 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py @@ -0,0 +1,115 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.ops import shuffle_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class ShuffleAndRepeatTest(test_base.DatasetTestBase): + + def _build_ds(self, seed, count=5, num_elements=20): + return dataset_ops.Dataset.range(num_elements).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + + def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): + get_next = ds_fn().make_one_shot_iterator().get_next() + outputs = [] + with self.cached_session() as sess: + for _ in range(num_outputs): + outputs.append(sess.run(get_next)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + return outputs + + def testCorrectOutput(self): + output = self._gen_outputs(lambda: self._build_ds(10), 100) + self.assertSequenceEqual( + sorted(output), sorted( + np.array([range(20) for _ in range(5)]).flatten())) + for i in range(5): + self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) + + def testReshuffling(self): + # Check that the output orders of different epochs are indeed different. + output = self._gen_outputs(lambda: self._build_ds(10), 100) + for i in range(4): + epoch1 = output[i * 20:(i + 1) * 20] + epoch2 = output[(i + 1) * 20:(i + 2) * 20] + self.assertNotEqual(epoch1, epoch2) + + def testSameOrderForSameSeeds(self): + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(10), 100) + self.assertEqual(output1, output2) + + def testDifferentOrderForDifferentSeeds(self): + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(20), 100) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountNone(self): + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=None), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=None), 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountMinusOne(self): + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testInfiniteOutputs(self): + # Asserting the iterator is exhausted after producing 100 items should fail. 
+ with self.assertRaises(AssertionError): + self._gen_outputs(lambda: self._build_ds(10, count=None), 100) + with self.assertRaises(AssertionError): + self._gen_outputs(lambda: self._build_ds(10, count=-1), 100) + + def testInfiniteEmpty(self): + with self.assertRaises(errors.OutOfRangeError): + self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), + 100) + with self.assertRaises(errors.OutOfRangeError): + self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), + 100) + + def testLargeBufferSize(self): + with ops.Graph().as_default() as g: + ds = dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=21)) + get_next_op = ds.make_one_shot_iterator().get_next() + with self.session(graph=g) as sess: + sess.run(get_next_op) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py new file mode 100644 index 0000000000..301f75488a --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py @@ -0,0 +1,590 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for experimental sql input op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): + + # Test that SqlDataset can read from a database table. + def testReadResultSet(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string), 2) + with self.cached_session() as sess: + for _ in range(2): # Run twice to verify statelessness of db operations. + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC" + }) + for _ in range(2): # Dataset is repeated. See setUp. + self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next)) + self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that SqlDataset works on a join query. 
+ def testReadResultSetJoinQuery(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT students.first_name, state, motto FROM students " + "INNER JOIN people " + "ON students.first_name = people.first_name " + "AND students.last_name = people.last_name" + }) + self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that SqlDataset can read a database entry with a null-terminator + # in the middle of the text and place the entry in a `string` tensor. + def testReadResultSetNullTerminator(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT first_name, last_name, favorite_nonsense_word " + "FROM students ORDER BY first_name DESC" + }) + self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next)) + self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that SqlDataset works when used on two different queries. + # Because the output types of the dataset must be determined at graph-creation + # time, the two queries must have the same number and types of columns. + def testReadResultSetReuseSqlDataset(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next)) + self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, last_name, state FROM people " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next)) + self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"), + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that an `OutOfRangeError` is raised on the first call to + # `get_next_str_only` if result set is empty. + def testReadEmptyResultSet(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, last_name, motto FROM students " + "WHERE first_name = 'Nonexistent'" + }) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that an error is raised when `driver_name` is invalid. 
+ def testReadResultSetWithInvalidDriverName(self): + init_op = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string))[0] + with self.cached_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run( + init_op, + feed_dict={ + self.driver_name: "sqlfake", + self.query: "SELECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC" + }) + + # Test that an error is raised when a column name in `query` is nonexistent + def testReadResultSetWithInvalidColumnName(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT first_name, last_name, fake_column FROM students " + "ORDER BY first_name DESC" + }) + with self.assertRaises(errors.UnknownError): + sess.run(get_next) + + # Test that an error is raised when there is a syntax error in `query`. + def testReadResultSetOfQueryWithSyntaxError(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELEmispellECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC" + }) + with self.assertRaises(errors.UnknownError): + sess.run(get_next) + + # Test that an error is raised when the number of columns in `query` + # does not match the length of `output_types`. + def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, last_name FROM students " + "ORDER BY first_name DESC" + }) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + # Test that no results are returned when `query` is an insert query rather + # than a select query. In particular, the error refers to the number of + # output types passed to the op not matching the number of columns in the + # result set of the query (namely, 0 for an insert statement.) + def testReadResultSetOfInsertQuery(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "INSERT INTO students (first_name, last_name, motto) " + "VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')" + }) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + # Test that `SqlDataset` can read an integer from a SQLite database table and + # place it in an `int8` tensor. + def testReadResultSetInt8(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, desk_number FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 9), sess.run(get_next)) + self.assertEqual((b"Jane", 127), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a negative or 0-valued integer from a + # SQLite database table and place it in an `int8` tensor. 
+ def testReadResultSetInt8NegativeAndZero(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8, + dtypes.int8)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, income, favorite_negative_number " + "FROM students " + "WHERE first_name = 'John' ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 0, -2), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a large (positive or negative) integer from + # a SQLite database table and place it in an `int8` tensor. + def testReadResultSetInt8MaxValues(self): + init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT desk_number, favorite_negative_number FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((9, -2), sess.run(get_next)) + # Max and min values of int8 + self.assertEqual((127, -128), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read an integer from a SQLite database table and + # place it in an `int16` tensor. + def testReadResultSetInt16(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, desk_number FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 9), sess.run(get_next)) + self.assertEqual((b"Jane", 127), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a negative or 0-valued integer from a + # SQLite database table and place it in an `int16` tensor. + def testReadResultSetInt16NegativeAndZero(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16, + dtypes.int16)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, income, favorite_negative_number " + "FROM students " + "WHERE first_name = 'John' ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 0, -2), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a large (positive or negative) integer from + # a SQLite database table and place it in an `int16` tensor. + def testReadResultSetInt16MaxValues(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, favorite_medium_sized_number " + "FROM students ORDER BY first_name DESC" + }) + # Max value of int16 + self.assertEqual((b"John", 32767), sess.run(get_next)) + # Min value of int16 + self.assertEqual((b"Jane", -32768), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read an integer from a SQLite database table and + # place it in an `int32` tensor. 
+ def testReadResultSetInt32(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, desk_number FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 9), sess.run(get_next)) + self.assertEqual((b"Jane", 127), sess.run(get_next)) + + # Test that `SqlDataset` can read a negative or 0-valued integer from a + # SQLite database table and place it in an `int32` tensor. + def testReadResultSetInt32NegativeAndZero(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, income FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 0), sess.run(get_next)) + self.assertEqual((b"Jane", -20000), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a large (positive or negative) integer from + # a SQLite database table and place it in an `int32` tensor. + def testReadResultSetInt32MaxValues(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, favorite_number FROM students " + "ORDER BY first_name DESC" + }) + # Max value of int32 + self.assertEqual((b"John", 2147483647), sess.run(get_next)) + # Min value of int32 + self.assertEqual((b"Jane", -2147483648), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database + # table and place it in an `int32` tensor. + def testReadResultSetInt32VarCharColumnAsInt(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, school_id FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 123), sess.run(get_next)) + self.assertEqual((b"Jane", 1000), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read an integer from a SQLite database table + # and place it in an `int64` tensor. + def testReadResultSetInt64(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, desk_number FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 9), sess.run(get_next)) + self.assertEqual((b"Jane", 127), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a negative or 0-valued integer from a + # SQLite database table and place it in an `int64` tensor. 
+ def testReadResultSetInt64NegativeAndZero(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, income FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 0), sess.run(get_next)) + self.assertEqual((b"Jane", -20000), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a large (positive or negative) integer from + # a SQLite database table and place it in an `int64` tensor. + def testReadResultSetInt64MaxValues(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT first_name, favorite_big_number FROM students " + "ORDER BY first_name DESC" + }) + # Max value of int64 + self.assertEqual((b"John", 9223372036854775807), sess.run(get_next)) + # Min value of int64 + self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read an integer from a SQLite database table and + # place it in a `uint8` tensor. + def testReadResultSetUInt8(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, desk_number FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 9), sess.run(get_next)) + self.assertEqual((b"Jane", 127), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read the minimum and maximum uint8 values from a + # SQLite database table and place them in `uint8` tensors. + def testReadResultSetUInt8MinAndMaxValues(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, brownie_points FROM students " + "ORDER BY first_name DESC" + }) + # Min value of uint8 + self.assertEqual((b"John", 0), sess.run(get_next)) + # Max value of uint8 + self.assertEqual((b"Jane", 255), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read an integer from a SQLite database table + # and place it in a `uint16` tensor. + def testReadResultSetUInt16(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, desk_number FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 9), sess.run(get_next)) + self.assertEqual((b"Jane", 127), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read the minimum and maximum uint16 values from a + # SQLite database table and place them in `uint16` tensors. 
+ def testReadResultSetUInt16MinAndMaxValues(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, account_balance FROM students " + "ORDER BY first_name DESC" + }) + # Min value of uint16 + self.assertEqual((b"John", 0), sess.run(get_next)) + # Max value of uint16 + self.assertEqual((b"Jane", 65535), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a 0-valued and 1-valued integer from a + # SQLite database table and place them as `True` and `False` respectively + # in `bool` tensors. + def testReadResultSetBool(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT first_name, registration_complete FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", True), sess.run(get_next)) + self.assertEqual((b"Jane", False), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued + # from a SQLite database table and place it as `True` in a `bool` tensor. + def testReadResultSetBoolNotZeroOrOne(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, favorite_medium_sized_number " + "FROM students ORDER BY first_name DESC" + }) + self.assertEqual((b"John", True), sess.run(get_next)) + self.assertEqual((b"Jane", True), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a float from a SQLite database table + # and place it in a `float64` tensor. + def testReadResultSetFloat64(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.float64)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT first_name, last_name, victories FROM townspeople " + "ORDER BY first_name" + }) + self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next)) + self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a float from a SQLite database table beyond + # the precision of 64-bit IEEE, without throwing an error. Test that + # `SqlDataset` identifies such a value as equal to itself. 
+ def testReadResultSetFloat64OverlyPrecise(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.float64)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT first_name, last_name, accolades FROM townspeople " + "ORDER BY first_name" + }) + self.assertEqual( + (b"George", b"Washington", + 1331241.321342132321324589798264627463827647382647382643874), + sess.run(get_next)) + self.assertEqual( + (b"John", b"Adams", + 1331241321342132321324589798264627463827647382647382643874.0), + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a float from a SQLite database table, + # representing the largest integer representable as a 64-bit IEEE float + # such that the previous integer is also representable as a 64-bit IEEE float. + # Test that `SqlDataset` can distinguish these two numbers. + def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.float64)) + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: + "SELECT first_name, last_name, triumphs FROM townspeople " + "ORDER BY first_name" + }) + self.assertNotEqual((b"George", b"Washington", 9007199254740992.0), + sess.run(get_next)) + self.assertNotEqual((b"John", b"Adams", 9007199254740991.0), + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py new file mode 100644 index 0000000000..a135c357f0 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
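For orientation, here is a minimal sketch of how the `SqlDataset` exercised by the tests above is typically driven outside the test harness. It uses the `tf.data.experimental.SqlDataset` endpoint that this change introduces and the graph-mode style of these tests; the database path and query are illustrative assumptions, not taken from the patch.

    import tensorflow as tf

    # Assumes a SQLite file with a `students` table shaped like the one built
    # in sql_dataset_op_test_base.py below; the path is hypothetical.
    dataset = tf.data.experimental.SqlDataset(
        driver_name="sqlite",
        data_source_name="/tmp/tftest.sqlite",
        query="SELECT first_name, desk_number FROM students",
        output_types=(tf.string, tf.int32))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
      sess.run(iterator.initializer)
      try:
        while True:
          print(sess.run(next_element))  # one (first_name, desk_number) row per call
      except tf.errors.OutOfRangeError:
        pass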
+# ============================================================================== +"""Base class for testing SqlDataset.""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import sqlite3 + +from tensorflow.python.data.experimental.ops import readers +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetTestBase(test_base.DatasetTestBase): + """Base class for setting up and testing SqlDataset.""" + + def _createSqlDataset(self, output_types, num_repeats=1): + dataset = readers.SqlDataset(self.driver_name, self.data_source_name, + self.query, output_types).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return init_op, get_next + + def setUp(self): + self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + self.driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + self.query = array_ops.placeholder(dtypes.string, shape=[]) + + conn = sqlite3.connect(self.data_source_name) + c = conn.cursor() + c.execute("DROP TABLE IF EXISTS students") + c.execute("DROP TABLE IF EXISTS people") + c.execute("DROP TABLE IF EXISTS townspeople") + c.execute( + "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " + "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " + "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " + "favorite_big_number INTEGER, favorite_negative_number INTEGER, " + "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " + "account_balance INTEGER, registration_complete INTEGER)") + c.executemany( + "INSERT INTO students (first_name, last_name, motto, school_id, " + "favorite_nonsense_word, desk_number, income, favorite_number, " + "favorite_big_number, favorite_negative_number, " + "favorite_medium_sized_number, brownie_points, account_balance, " + "registration_complete) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, + 9223372036854775807, -2, 32767, 0, 0, 1), + ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, + -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) + c.execute( + "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") + c.executemany( + "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", + [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", + "California")]) + c.execute( + "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " + "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " + "FLOAT, accolades FLOAT, triumphs FLOAT)") + c.executemany( + "INSERT INTO townspeople (first_name, last_name, victories, " + "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", + [("George", "Washington", 20.00, + 1331241.321342132321324589798264627463827647382647382643874, + 9007199254740991.0), + ("John", "Adams", -19.95, + 1331241321342132321324589798264627463827647382647382643874.0, + 9007199254740992.0)]) + conn.commit() + conn.close() diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py 
b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py new file mode 100644 index 0000000000..6761fbd16b --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py @@ -0,0 +1,253 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base +from tensorflow.python.data.experimental.ops import stats_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): + + def testBytesProduced(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + expected_sum = 0.0 + for i in range(100): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1)) + expected_sum += i * 8.0 + self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0) + self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) + + def testLatencyStats(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + + def testPrefetchBufferUtilization(self): + 
stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch( + -1).apply(stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + float(i + 1)) + self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity") + self._assertSummaryContains(summary_str, "Prefetch::buffer_size") + self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization", + 0, 1) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + 100) + + def testPrefetchBufferScalars(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(10).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch( + 0).apply(stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasScalarValue(summary_str, + "Prefetch::buffer_capacity", 0) + self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size", + 0) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testFilteredElementsStats(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(101).filter( + lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(34): + self.assertEqual(i * 3, sess.run(next_element)) + if i is not 0: + self._assertSummaryHasScalarValue( + sess.run(summary_t), "Filter::dropped_elements", float(i * 2)) + self._assertSummaryHasScalarValue( + sess.run(summary_t), "Filter::filtered_elements", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasScalarValue( + sess.run(summary_t), "Filter::dropped_elements", 67.0) + self._assertSummaryHasScalarValue( + sess.run(summary_t), "Filter::filtered_elements", 34.0) + + def testReinitialize(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + for j in range(5): + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", 
float((j * 100) + i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", (j + 1) * 100.0) + + def testNoAggregatorRegistered(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMultipleTags(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.latency_stats("record_latency_2")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(i + 1)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency_2", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency_2", 100.0) + + def testRepeatedTags(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(2 * (i + 1))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + + def testMultipleIteratorsSameAggregator(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator_0 = dataset.make_initializable_iterator() + iterator_1 = dataset.make_initializable_iterator() + next_element = iterator_0.get_next() + iterator_1.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run([iterator_0.initializer, iterator_1.initializer]) + for i in range(100): + self.assertEqual(i * 2, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(2 * (i + 1))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py new file mode 100644 index 0000000000..80f2625927 
--- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py @@ -0,0 +1,71 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing the input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.data.kernel_tests import test_base + + +class StatsDatasetTestBase(test_base.DatasetTestBase): + """Base class for testing statistics gathered in `StatsAggregator`.""" + + def _assertSummaryContains(self, summary_str, tag): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertLessEqual(min_value, value.histo.min) + self.assertGreaterEqual(max_value, value.histo.max) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.simple_value) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) diff --git a/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py new file mode 100644 index 0000000000..4432dcb05a --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py @@ -0,0 +1,91 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
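The stats tests above all follow the same pattern: attach a stats transformation to the pipeline, route it to an aggregator, and poll the aggregator's summary. A condensed sketch using the `tf.data.experimental` endpoints that this change exports (illustrative only, not part of the patch):

    import tensorflow as tf

    aggregator = tf.data.experimental.StatsAggregator()
    dataset = (tf.data.Dataset.range(100)
               .apply(tf.data.experimental.latency_stats("record_latency"))
               .apply(tf.data.experimental.set_stats_aggregator(aggregator)))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    summary_t = aggregator.get_summary()  # serialized tf.Summary with histograms

    with tf.Session() as sess:
      sess.run(iterator.initializer)
      sess.run(next_element)       # produce one element, recording its latency
      print(sess.run(summary_t))   # inspect (or write) the accumulated stats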
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline thread pool override ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import threadpool
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
+                                    parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ("1", 1, None),
+      ("2", 2, None),
+      ("3", 4, None),
+      ("4", 8, None),
+      ("5", 16, None),
+      ("6", 4, -1),
+      ("7", 4, 0),
+      ("8", 4, 1),
+      ("9", 4, 4),
+  )
+  def testNumThreads(self, num_threads, max_intra_op_parallelism):
+
+    def get_thread_id(_):
+      # Python creates a dummy thread object to represent the current
+      # thread when called from an "alien" thread (such as a
+      # `PrivateThreadPool` thread in this case). It does not include
+      # the TensorFlow-given display name, but it has a unique
+      # identifier that maps one-to-one with the underlying OS thread.
+      return np.array(threading.current_thread().ident).astype(np.int64)
+
+    dataset = (
+        dataset_ops.Dataset.range(1000).map(
+            lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
+            num_parallel_calls=32).apply(unique.unique()))
+
+    dataset = threadpool.override_threadpool(
+        dataset,
+        threadpool.PrivateThreadPool(
+            num_threads,
+            max_intra_op_parallelism=max_intra_op_parallelism,
+            display_name="private_thread_pool_%d" % num_threads))
+
+    iterator = dataset.make_initializable_iterator()
+    next_element = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(iterator.initializer)
+      thread_ids = []
+      try:
+        while True:
+          thread_ids.append(sess.run(next_element))
+      except errors.OutOfRangeError:
+        pass
+      self.assertEqual(len(thread_ids), len(set(thread_ids)))
+      self.assertGreater(len(thread_ids), 0)
+      # NOTE(mrry): We don't control the thread pool scheduling, and
+      # so cannot guarantee that all of the threads in the pool will
+      # perform work.
+      self.assertLessEqual(len(thread_ids), num_threads)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
new file mode 100644
index 0000000000..b5a0b20f3f
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
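The thread pool test above relies on two helpers that this change moves into `tensorflow/python/data/experimental/ops/threadpool.py`. A minimal sketch of the intended usage, taken from the module paths that appear in the test; the pool size and display name are illustrative:

    from tensorflow.python.data.experimental.ops import threadpool
    from tensorflow.python.data.ops import dataset_ops

    dataset = dataset_ops.Dataset.range(1000)
    # Run this dataset's ops on a dedicated 4-thread pool rather than the
    # session's shared inter-op pool.
    dataset = threadpool.override_threadpool(
        dataset,
        threadpool.PrivateThreadPool(
            4, max_intra_op_parallelism=1,
            display_name="my_private_thread_pool"))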
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import unique +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class UniqueDatasetTest(test_base.DatasetTestBase): + + def _testSimpleHelper(self, dtype, test_cases): + """Test the `unique()` transformation on a list of test cases. + + Args: + dtype: The `dtype` of the elements in each test case. + test_cases: A list of pairs of lists. The first component is the test + input that will be passed to the transformation; the second component + is the expected sequence of outputs from the transformation. + """ + + # The `current_test_case` will be updated when we loop over `test_cases` + # below; declare it here so that the generator can capture it once. + current_test_case = [] + dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case, + dtype).apply(unique.unique()) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + for test_case, expected in test_cases: + current_test_case = test_case + sess.run(iterator.initializer) + for element in expected: + if dtype == dtypes.string: + element = compat.as_bytes(element) + self.assertAllEqual(element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testSimpleInt(self): + for dtype in [dtypes.int32, dtypes.int64]: + self._testSimpleHelper(dtype, [ + ([], []), + ([1], [1]), + ([1, 1, 1, 1, 1, 1, 1], [1]), + ([1, 2, 3, 4], [1, 2, 3, 4]), + ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]), + ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]), + ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]), + ]) + + def testSimpleString(self): + self._testSimpleHelper(dtypes.string, [ + ([], []), + (["hello"], ["hello"]), + (["hello", "hello", "hello"], ["hello"]), + (["hello", "world"], ["hello", "world"]), + (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]), + ]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py new file mode 100644 index 0000000000..25a2e63ba1 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py @@ -0,0 +1,118 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
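The `unique()` transformation tested above has a one-line usage; a sketch mirroring one of the integer test cases, using the `tf.data.experimental.unique` endpoint that this change exports:

    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 4, 3, 2, 1, 2, 3, 4])
    dataset = dataset.apply(tf.data.experimental.unique())
    # Iterating yields 1, 2, 4, 3: first occurrences only, in encounter order.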
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.data.experimental.ops import writers +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class TFRecordWriterTest(test_base.DatasetTestBase): + + def setUp(self): + super(TFRecordWriterTest, self).setUp() + self._num_records = 7 + self.filename = array_ops.placeholder(dtypes.string, shape=[]) + self.compression_type = array_ops.placeholder_with_default("", shape=[]) + + input_dataset = readers.TFRecordDataset([self.filename], + self.compression_type) + self.writer = writers.TFRecordWriter( + self._outputFilename(), self.compression_type).write(input_dataset) + + def _record(self, i): + return compat.as_bytes("Record %d" % (i)) + + def _createFile(self, options=None): + filename = self._inputFilename() + writer = python_io.TFRecordWriter(filename, options) + for i in range(self._num_records): + writer.write(self._record(i)) + writer.close() + return filename + + def _inputFilename(self): + return os.path.join(self.get_temp_dir(), "tf_record.in.txt") + + def _outputFilename(self): + return os.path.join(self.get_temp_dir(), "tf_record.out.txt") + + def testWrite(self): + with self.cached_session() as sess: + sess.run( + self.writer, feed_dict={ + self.filename: self._createFile(), + }) + for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())): + self.assertAllEqual(self._record(i), r) + + def testWriteZLIB(self): + options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) + with self.cached_session() as sess: + sess.run( + self.writer, + feed_dict={ + self.filename: self._createFile(options), + self.compression_type: "ZLIB", + }) + for i, r in enumerate( + tf_record.tf_record_iterator(self._outputFilename(), options=options)): + self.assertAllEqual(self._record(i), r) + + def testWriteGZIP(self): + options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) + with self.cached_session() as sess: + sess.run( + self.writer, + feed_dict={ + self.filename: self._createFile(options), + self.compression_type: "GZIP", + }) + for i, r in enumerate( + tf_record.tf_record_iterator(self._outputFilename(), options=options)): + self.assertAllEqual(self._record(i), r) + + def testFailDataset(self): + with self.assertRaises(TypeError): + writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write("whoops") + + def testFailDType(self): + input_dataset = dataset_ops.Dataset.from_tensors(10) + with self.assertRaises(TypeError): + 
writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write(input_dataset) + + def testFailShape(self): + input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]]) + with self.assertRaises(TypeError): + writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write(input_dataset) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD new file mode 100644 index 0000000000..915d399f1b --- /dev/null +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -0,0 +1,377 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_wrapper_py", + "tf_kernel_library", +) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") + +py_library( + name = "counter", + srcs = ["counter.py"], + srcs_version = "PY2AND3", + deps = [ + ":scan_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( + name = "get_single_element", + srcs = ["get_single_element.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "iterator_ops", + srcs = [ + "iterator_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:basic_session_run_hooks", + "//tensorflow/python:checkpoint_management", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:saver", + "//tensorflow/python:session_run_hook", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:optional_ops", + ], +) + +py_library( + name = "random_ops", + srcs = [ + "random_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "readers", + srcs = [ + "readers.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":batching", + ":interleave_ops", + ":optimization", + ":parsing_ops", + ":shuffle_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:convert", + "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", + ], +) + +py_library( + name = "shuffle_ops", + srcs = [ + "shuffle_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( + name = "batching", + srcs = ["batching.py"], + srcs_version = "PY2AND3", + deps = [ + ":get_single_element", + ":grouping", + "//tensorflow/python:array_ops", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + 
"//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + "//third_party/py/numpy", + ], +) + +py_library( + name = "enumerate_ops", + srcs = ["enumerate_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( + name = "error_ops", + srcs = ["error_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "grouping", + srcs = ["grouping.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "interleave_ops", + srcs = ["interleave_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":random_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:stateless_random_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "optimization", + srcs = ["optimization.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "parsing_ops", + srcs = ["parsing_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +py_library( + name = "map_defun", + srcs = ["map_defun.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_shape", + ], +) + +py_library( + name = "resampling", + srcs = ["resampling.py"], + srcs_version = "PY2AND3", + deps = [ + ":batching", + ":interleave_ops", + ":scan_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:logging_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_library( + name = "scan_ops", + srcs = ["scan_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + 
"//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "stats_ops", + srcs = ["stats_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "threadpool", + srcs = ["threadpool.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + "//tensorflow/python/eager:context", + ], +) + +py_library( + name = "unique", + srcs = [ + "unique.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "writers", + srcs = [ + "writers.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( + name = "indexed_dataset_ops", + srcs = ["indexed_dataset_ops.py"], + deps = [ + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "prefetching_ops", + srcs = ["prefetching_ops.py"], + deps = [ + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "dataset_ops", + deps = [ + ":batching", + ":counter", + ":enumerate_ops", + ":error_ops", + ":get_single_element", + ":grouping", + ":indexed_dataset_ops", + ":interleave_ops", + ":map_defun", + ":optimization", + ":prefetching_ops", + ":readers", + ":resampling", + ":scan_ops", + ":shuffle_ops", + ":stats_ops", + ":threadpool", + ":unique", + ":writers", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py new file mode 100644 index 0000000000..d42af9e7e9 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -0,0 +1,669 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Batching dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.ops import get_single_element +from tensorflow.python.data.experimental.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.util.tf_export import tf_export + + +def batch_window(dataset): + """Batches a window of tensors. + + Args: + dataset: the input dataset. + + Returns: + A `Tensor` representing the batch of the entire input dataset. + """ + if isinstance(dataset.output_classes, tuple): + raise TypeError("Input dataset expected to have a single component") + if dataset.output_classes is ops.Tensor: + return _batch_dense_window(dataset) + elif dataset.output_classes is sparse_tensor.SparseTensor: + return _batch_sparse_window(dataset) + else: + raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) + + +def _batch_dense_window(dataset): + """Batches a window of dense tensors.""" + + def key_fn(_): + return np.int64(0) + + def shape_init_fn(_): + return array_ops.shape(first_element) + + def shape_reduce_fn(state, value): + check_ops.assert_equal(state, array_ops.shape(value)) + return state + + def finalize_fn(state): + return state + + if dataset.output_shapes.is_fully_defined(): + shape = dataset.output_shapes + else: + first_element = get_single_element.get_single_element(dataset.take(1)) + shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, + finalize_fn) + shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) + + def batch_init_fn(_): + batch_shape = array_ops.concat([[0], shape], 0) + return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) + + def batch_reduce_fn(state, value): + return array_ops.concat([state, [value]], 0) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer))) + + +def _batch_sparse_window(dataset): + """Batches a window of sparse tensors.""" + + def key_fn(_): + return np.int64(0) + + def shape_init_fn(_): + return first_element.dense_shape + + def shape_reduce_fn(state, value): + check_ops.assert_equal(state, value.dense_shape) + return state + + def finalize_fn(state): + return state + + if dataset.output_shapes.is_fully_defined(): + shape = dataset.output_shapes + else: + first_element = get_single_element.get_single_element(dataset.take(1)) + shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, + finalize_fn) + shape = 
get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) + + def batch_init_fn(_): + indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0) + return sparse_tensor.SparseTensor( + indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), + values=constant_op.constant([], shape=[0], dtype=dataset.output_types), + dense_shape=array_ops.concat( + [np.array([0], dtype=np.int64), + math_ops.cast(shape, dtypes.int64)], 0)) + + def batch_reduce_fn(state, value): + return sparse_ops.sparse_concat(0, [state, value]) + + def reshape_fn(value): + return sparse_ops.sparse_reshape( + value, + array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0)) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.map(reshape_fn).apply( + grouping.group_by_reducer(key_fn, batch_reducer))) + + +@tf_export("data.experimental.dense_to_sparse_batch") +def dense_to_sparse_batch(batch_size, row_shape): + """A transformation that batches ragged elements into `tf.SparseTensor`s. + + Like `Dataset.padded_batch()`, this transformation combines multiple + consecutive elements of the dataset, which might have different + shapes, into a single element. The resulting element has three + components (`indices`, `values`, and `dense_shape`), which + comprise a `tf.SparseTensor` that represents the same data. The + `row_shape` represents the dense shape of each row in the + resulting `tf.SparseTensor`, to which the effective batch size is + prepended. For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } + + a.apply(tf.data.experimental.dense_to_sparse_batch( + batch_size=2, row_shape=[6])) == + { + ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices + ['a', 'b', 'c', 'a', 'b'], # values + [2, 6]), # dense_shape + ([[0, 0], [0, 1], [0, 2], [0, 3]], + ['a', 'b', 'c', 'd'], + [1, 6]) + } + ``` + + Args: + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the + number of consecutive elements of this dataset to combine in a + single batch. + row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like + object representing the equivalent dense shape of a row in the + resulting `tf.SparseTensor`. Each element of this dataset must + have the same rank as `row_shape`, and must have size less + than or equal to `row_shape` in each dimension. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + return _DenseToSparseBatchDataset(dataset, batch_size, row_shape) + + return _apply_fn + + +def padded_batch_window(dataset, padded_shape, padding_value=None): + """Batches a window of tensors with padding. + + Args: + dataset: the input dataset. + padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like + object representing the shape to which the input elements should be padded + prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a + `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the + maximum size of that dimension in each batch. + padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the + padding value to use. Defaults are `0` for numeric types and the empty + string for string types. If `dataset` contains `tf.SparseTensor`, this + value is ignored. 
+ + Returns: + A `Tensor` representing the batch of the entire input dataset. + + Raises: + ValueError: if invalid arguments are provided. + """ + if not issubclass(dataset.output_classes, + (ops.Tensor, sparse_tensor.SparseTensor)): + raise TypeError("Input dataset expected to have a single tensor component") + if issubclass(dataset.output_classes, (ops.Tensor)): + return _padded_batch_dense_window(dataset, padded_shape, padding_value) + elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)): + if padding_value is not None: + raise ValueError("Padding value not allowed for sparse tensors") + return _padded_batch_sparse_window(dataset, padded_shape) + else: + raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) + + +def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): + """Batches a window of dense tensors with padding.""" + + padded_shape = math_ops.cast( + convert.partial_shape_to_tensor(padded_shape), dtypes.int32) + + def key_fn(_): + return np.int64(0) + + def max_init_fn(_): + return padded_shape + + def max_reduce_fn(state, value): + """Computes the maximum shape to pad to.""" + condition = math_ops.reduce_all( + math_ops.logical_or( + math_ops.less_equal(array_ops.shape(value), padded_shape), + math_ops.equal(padded_shape, -1))) + assert_op = control_flow_ops.Assert(condition, [ + "Actual shape greater than padded shape: ", + array_ops.shape(value), padded_shape + ]) + with ops.control_dependencies([assert_op]): + return math_ops.maximum(state, array_ops.shape(value)) + + def finalize_fn(state): + return state + + # Compute the padded shape. + max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) + padded_shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) + + if padding_value is None: + if dataset.output_types == dtypes.string: + padding_value = "" + elif dataset.output_types == dtypes.bool: + padding_value = False + elif dataset.output_types == dtypes.variant: + raise TypeError("Unable to create padding for field of type 'variant'") + else: + padding_value = 0 + + def batch_init_fn(_): + batch_shape = array_ops.concat( + [np.array([0], dtype=np.int32), padded_shape], 0) + return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) + + def batch_reduce_fn(state, value): + return array_ops.concat([state, [value]], 0) + + def pad_fn(value): + shape = array_ops.shape(value) + left = array_ops.zeros_like(shape) + right = padded_shape - shape + return array_ops.pad( + value, array_ops.stack([left, right], 1), constant_values=padding_value) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.map(pad_fn).apply( + grouping.group_by_reducer(key_fn, batch_reducer))) + + +def _padded_batch_sparse_window(dataset, padded_shape): + """Batches a window of sparse tensors with padding.""" + + def key_fn(_): + return np.int64(0) + + def max_init_fn(_): + return convert.partial_shape_to_tensor(padded_shape) + + def max_reduce_fn(state, value): + """Computes the maximum shape to pad to.""" + condition = math_ops.reduce_all( + math_ops.logical_or( + math_ops.less_equal(value.dense_shape, padded_shape), + math_ops.equal(padded_shape, -1))) + assert_op = control_flow_ops.Assert(condition, [ + "Actual shape greater than padded shape: ", value.dense_shape, + padded_shape + ]) + with ops.control_dependencies([assert_op]): + return math_ops.maximum(state, value.dense_shape) + + 
def finalize_fn(state): + return state + + # Compute the padded shape. + max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) + padded_shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) + + def batch_init_fn(_): + indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]], + 0) + return sparse_tensor.SparseTensor( + indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), + values=constant_op.constant([], shape=[0], dtype=dataset.output_types), + dense_shape=array_ops.concat( + [np.array([0], dtype=np.int64), padded_shape], 0)) + + def batch_reduce_fn(state, value): + padded_value = sparse_tensor.SparseTensor( + indices=value.indices, values=value.values, dense_shape=padded_shape) + reshaped_value = sparse_ops.sparse_reshape( + padded_value, + array_ops.concat( + [np.array([1], dtype=np.int64), padded_value.dense_shape], 0)) + return sparse_ops.sparse_concat(0, [state, reshaped_value]) + + reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, reducer))) + + +class _UnbatchDataset(dataset_ops.UnaryDataset): + """A dataset that splits the elements of its input into multiple elements.""" + + def __init__(self, input_dataset): + """See `unbatch()` for more details.""" + super(_UnbatchDataset, self).__init__(input_dataset) + flat_shapes = nest.flatten(input_dataset.output_shapes) + if any(s.ndims == 0 for s in flat_shapes): + raise ValueError("Cannot unbatch an input with scalar components.") + known_batch_dim = tensor_shape.Dimension(None) + for s in flat_shapes: + try: + known_batch_dim = known_batch_dim.merge_with(s[0]) + except ValueError: + raise ValueError("Cannot unbatch an input whose components have " + "different batch sizes.") + self._input_dataset = input_dataset + + def _as_variant_tensor(self): + return gen_dataset_ops.unbatch_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return nest.map_structure(lambda s: s[1:], + self._input_dataset.output_shapes) + + @property + def output_types(self): + return self._input_dataset.output_types + + +@tf_export("data.experimental.unbatch") +def unbatch(): + """Splits elements of a dataset into multiple elements on the batch dimension. + + For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, + where `B` may vary for each input element, then for each element in the + dataset, the unbatched dataset will contain `B` consecutive elements + of shape `[a0, a1, ...]`. + + ```python + # NOTE: The following example uses `{ ... }` to represent the contents + # of a dataset. + a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } + + a.apply(tf.data.experimental.unbatch()) == { + 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'} + ``` + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. 
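+
+  A minimal sketch of the same idea with concrete values (the input tensors
+  below are assumptions chosen for illustration, not part of this change):
+
+  ```python
+  batched = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
+  # `batched` contains the elements [1, 2] and [3, 4].
+  flat = batched.apply(tf.data.experimental.unbatch())
+  # `flat` contains the scalar elements 1, 2, 3, 4.
+  ```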
+ """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + if not sparse.any_sparse(dataset.output_classes): + return _UnbatchDataset(dataset) + + # NOTE(mrry): We must ensure that any SparseTensors in `dataset` + # are normalized to the rank-1 dense representation, so that the + # sparse-oblivious unbatching logic will slice them + # appropriately. This leads to a somewhat inefficient re-encoding step + # for all SparseTensor components. + # TODO(mrry): Consider optimizing this in future + # if it turns out to be a bottleneck. + def normalize(arg, *rest): + if rest: + return sparse.serialize_many_sparse_tensors((arg,) + rest) + else: + return sparse.serialize_many_sparse_tensors(arg) + + normalized_dataset = dataset.map(normalize) + + # NOTE(mrry): Our `map()` has lost information about the sparseness + # of any SparseTensor components, so re-apply the structure of the + # original dataset. + restructured_dataset = _RestructuredDataset( + normalized_dataset, + dataset.output_types, + dataset.output_shapes, + dataset.output_classes, + allow_unsafe_cast=True) + return _UnbatchDataset(restructured_dataset) + + return _apply_fn + + +class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset): + """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" + + def __init__(self, input_dataset, batch_size, row_shape): + """See `Dataset.dense_to_sparse_batch()` for more details.""" + super(_DenseToSparseBatchDataset, self).__init__(input_dataset) + if not isinstance(input_dataset.output_types, dtypes.DType): + raise TypeError("DenseToSparseDataset requires an input whose elements " + "have a single component, whereas the input has %r." % + input_dataset.output_types) + self._input_dataset = input_dataset + self._batch_size = batch_size + self._row_shape = row_shape + + def _as_variant_tensor(self): + return gen_dataset_ops.dense_to_sparse_batch_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._batch_size, + row_shape=convert.partial_shape_to_tensor(self._row_shape), + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return sparse_tensor.SparseTensor + + @property + def output_shapes(self): + return tensor_shape.vector(None).concatenate(self._row_shape) + + @property + def output_types(self): + return self._input_dataset.output_types + + +class _RestructuredDataset(dataset_ops.UnaryDataset): + """An internal helper for changing the structure and shape of a dataset.""" + + def __init__(self, + dataset, + output_types, + output_shapes=None, + output_classes=None, + allow_unsafe_cast=False): + """Creates a new dataset with the given output types and shapes. + + The given `dataset` must have a structure that is convertible: + * `dataset.output_types` must be the same as `output_types` module nesting. + * Each shape in `dataset.output_shapes` must be compatible with each shape + in `output_shapes` (if given). + + Note: This helper permits "unsafe casts" for shapes, equivalent to using + `tf.Tensor.set_shape()` where domain-specific knowledge is available. + + Args: + dataset: A `Dataset` object. + output_types: A nested structure of `tf.DType` objects. + output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. + If omitted, the shapes will be inherited from `dataset`. + output_classes: (Optional.) A nested structure of class types. + If omitted, the class types will be inherited from `dataset`. + allow_unsafe_cast: (Optional.) 
If `True`, the caller may switch the + reported output types and shapes of the restructured dataset, e.g. to + switch a sparse tensor represented as `tf.variant` to its user-visible + type and shape. + + Raises: + ValueError: If either `output_types` or `output_shapes` is not compatible + with the structure of `dataset`. + """ + super(_RestructuredDataset, self).__init__(dataset) + self._input_dataset = dataset + + if not allow_unsafe_cast: + # Validate that the types are compatible. + output_types = nest.map_structure(dtypes.as_dtype, output_types) + flat_original_types = nest.flatten(dataset.output_types) + flat_new_types = nest.flatten(output_types) + if flat_original_types != flat_new_types: + raise ValueError( + "Dataset with output types %r cannot be restructured to have " + "output types %r" % (dataset.output_types, output_types)) + + self._output_types = output_types + + if output_shapes is None: + # Inherit shapes from the original `dataset`. + self._output_shapes = nest.pack_sequence_as(output_types, + nest.flatten( + dataset.output_shapes)) + else: + if not allow_unsafe_cast: + # Validate that the shapes are compatible. + nest.assert_same_structure(output_types, output_shapes) + flat_original_shapes = nest.flatten(dataset.output_shapes) + flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) + + for original_shape, new_shape in zip(flat_original_shapes, + flat_new_shapes): + if not original_shape.is_compatible_with(new_shape): + raise ValueError( + "Dataset with output shapes %r cannot be restructured to have " + "incompatible output shapes %r" % (dataset.output_shapes, + output_shapes)) + self._output_shapes = nest.map_structure_up_to( + output_types, tensor_shape.as_shape, output_shapes) + if output_classes is None: + # Inherit class types from the original `dataset`. 
+ self._output_classes = nest.pack_sequence_as(output_types, + nest.flatten( + dataset.output_classes)) + else: + self._output_classes = output_classes + + def _as_variant_tensor(self): + return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + + @property + def output_classes(self): + return self._output_classes + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + +class _MapAndBatchDataset(dataset_ops.MapDataset): + """A `Dataset` that maps a function over a batch of elements.""" + + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, + drop_remainder): + """See `Dataset.map()` for details.""" + super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) + self._batch_size_t = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") + self._drop_remainder_t = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") + + self._batch_size = batch_size + self._drop_remainder = drop_remainder + + def _as_variant_tensor(self): + # pylint: disable=protected-access + input_resource = self._input_dataset._as_variant_tensor() + return gen_dataset_ops.map_and_batch_dataset_v2( + input_resource, + self._map_func.captured_inputs, + f=self._map_func, + batch_size=self._batch_size_t, + num_parallel_calls=self._num_parallel_calls_t, + drop_remainder=self._drop_remainder_t, + **dataset_ops.flat_structure(self)) + # pylint: enable=protected-access + + @property + def output_shapes(self): + dim = self._batch_size if self._drop_remainder else None + return nest.pack_sequence_as(self._output_shapes, [ + tensor_shape.vector(dim).concatenate(s) + for s in nest.flatten(self._output_shapes) + ]) + + @property + def output_types(self): + return self._output_types + + +@tf_export("data.experimental.map_and_batch") +def map_and_batch(map_func, + batch_size, + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): + """Fused implementation of `map` and `batch`. + + Maps `map_func` across `batch_size` consecutive elements of this dataset + and then combines them into a batch. Functionally, it is equivalent to `map` + followed by `batch`. However, by fusing the two transformations together, the + implementation can be more efficient. Surfacing this transformation in the API + is temporary. Once automatic input pipeline optimization is implemented, + the fusing of `map` and `batch` will happen automatically and this API will be + deprecated. + + Args: + map_func: A function mapping a nested structure of tensors to another + nested structure of tensors. + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) 
A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. + """ + + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + + def _apply_fn(dataset): + return _MapAndBatchDataset(dataset, map_func, batch_size, + num_parallel_calls, drop_remainder) + + return _apply_fn diff --git a/tensorflow/python/data/experimental/ops/counter.py b/tensorflow/python/data/experimental/ops/counter.py new file mode 100644 index 0000000000..42200eaef9 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/counter.py @@ -0,0 +1,55 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Counter Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import scan_ops + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.Counter") +def Counter(start=0, step=1, dtype=dtypes.int64): + """Creates a `Dataset` that counts from `start` in steps of size `step`. + + For example: + + ```python + Dataset.count() == [0, 1, 2, ...) + Dataset.count(2) == [2, 3, ...) + Dataset.count(2, 5) == [2, 7, 12, ...) + Dataset.count(0, -1) == [0, -1, -2, ...) + Dataset.count(10, -1) == [10, 9, ...) + ``` + + Args: + start: (Optional.) The starting value for the counter. Defaults to 0. + step: (Optional.) The step size for the counter. Defaults to 1. + dtype: (Optional.) The data type for counter elements. Defaults to + `tf.int64`. + + Returns: + A `Dataset` of scalar `dtype` elements. 
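+
+  As an illustrative sketch (the `labels` dataset below is an assumption made
+  for this example), `Counter` can be combined with `Dataset.zip()` to pair
+  each element of another dataset with its index:
+
+  ```python
+  labels = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
+  indexed = tf.data.Dataset.zip((tf.data.experimental.Counter(), labels))
+  # `indexed` contains (0, "a"), (1, "b"), (2, "c").
+  ```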
+ """ + with ops.name_scope("counter"): + start = ops.convert_to_tensor(start, dtype=dtype, name="start") + step = ops.convert_to_tensor(step, dtype=dtype, name="step") + return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( + scan_ops.scan(start, lambda state, _: (state + step, state))) diff --git a/tensorflow/python/data/experimental/ops/enumerate_ops.py b/tensorflow/python/data/experimental/ops/enumerate_ops.py new file mode 100644 index 0000000000..a1af98f552 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/enumerate_ops.py @@ -0,0 +1,60 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Enumerate dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.enumerate_dataset") +def enumerate_dataset(start=0): + """A transformation that enumerate the elements of a dataset. + + It is Similar to python's `enumerate`. + For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { 1, 2, 3 } + b = { (7, 8), (9, 10) } + + # The nested structure of the `datasets` argument determines the + # structure of elements in the resulting dataset. + a.apply(tf.data.experimental.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) } + b.apply(tf.data.experimental.enumerate()) == { (0, (7, 8)), (1, (9, 10)) } + ``` + + Args: + start: A `tf.int64` scalar `tf.Tensor`, representing the start + value for enumeration. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max + return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value), + dataset)) + + return _apply_fn diff --git a/tensorflow/python/data/experimental/ops/error_ops.py b/tensorflow/python/data/experimental/ops/error_ops.py new file mode 100644 index 0000000000..82e274b70c --- /dev/null +++ b/tensorflow/python/data/experimental/ops/error_ops.py @@ -0,0 +1,78 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Ignore_errors dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.ignore_errors")
+def ignore_errors():
+  """Creates a `Dataset` from another `Dataset` and silently ignores any errors.
+
+  Use this transformation to produce a dataset that contains the same elements
+  as the input, but silently drops any elements that caused an error. For
+  example:
+
+  ```python
+  dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
+
+  # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
+  dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))
+
+  # Using `ignore_errors()` will drop the element that causes an error.
+  dataset = dataset.apply(
+      tf.data.experimental.ignore_errors())  # ==> {1., 0.5, 0.25}
+  ```
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    `tf.data.Dataset.apply`.
+  """
+
+  def _apply_fn(dataset):
+    return _IgnoreErrorsDataset(dataset)
+
+  return _apply_fn
+
+
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
+  """A `Dataset` that silently ignores errors when computing its input."""
+
+  def __init__(self, input_dataset):
+    """See `Dataset.ignore_errors()` for details."""
+    super(_IgnoreErrorsDataset, self).__init__(input_dataset)
+    self._input_dataset = input_dataset
+
+  def _as_variant_tensor(self):
+    return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        **dataset_ops.flat_structure(self))
+
+  @property
+  def output_classes(self):
+    return self._input_dataset.output_classes
+
+  @property
+  def output_shapes(self):
+    return self._input_dataset.output_shapes
+
+  @property
+  def output_types(self):
+    return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/get_single_element.py b/tensorflow/python/data/experimental/ops/get_single_element.py
new file mode 100644
index 0000000000..132526166c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/get_single_element.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Python wrappers for Datasets and Iterators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.get_single_element") +def get_single_element(dataset): + """Returns the single element in `dataset` as a nested structure of tensors. + + This function enables you to use a `tf.data.Dataset` in a stateless + "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`. + This can be useful when your preprocessing transformations are expressed + as a `Dataset`, and you want to use the transformation at serving time. + For example: + + ```python + input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) + + def preprocessing_fn(input_str): + # ... + return image, label + + dataset = (tf.data.Dataset.from_tensor_slices(input_batch) + .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) + .batch(BATCH_SIZE)) + + image_batch, label_batch = tf.data.experimental.get_single_element(dataset) + ``` + + Args: + dataset: A `tf.data.Dataset` object containing a single element. + + Returns: + A nested structure of `tf.Tensor` objects, corresponding to the single + element of `dataset`. + + Raises: + TypeError: if `dataset` is not a `tf.data.Dataset` object. + InvalidArgumentError (at runtime): if `dataset` does not contain exactly + one element. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + + nested_ret = nest.pack_sequence_as( + dataset.output_types, gen_dataset_ops.dataset_to_single_element( + dataset._as_variant_tensor(), # pylint: disable=protected-access + **dataset_ops.flat_structure(dataset))) + return sparse.deserialize_sparse_tensors( + nested_ret, dataset.output_types, dataset.output_shapes, + dataset.output_classes) diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py new file mode 100644 index 0000000000..18ba583220 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/grouping.py @@ -0,0 +1,551 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Grouping dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.group_by_reducer") +def group_by_reducer(key_func, reducer): + """A transformation that groups elements and performs a reduction. + + This transformation maps element of a dataset to a key using `key_func` and + groups the elements by key. The `reducer` is used to process each group; its + `init_func` is used to initialize state for each group when it is created, the + `reduce_func` is used to update the state every time an element is mapped to + the matching group, and the `finalize_func` is used to map the final state to + an output value. + + Args: + key_func: A function mapping a nested structure of tensors + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to a scalar `tf.int64` tensor. + reducer: An instance of `Reducer`, which captures the reduction logic using + the `init_func`, `reduce_func`, and `finalize_func` functions. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _GroupByReducerDataset(dataset, key_func, reducer) + + return _apply_fn + + +@tf_export("data.experimental.group_by_window") +def group_by_window(key_func, + reduce_func, + window_size=None, + window_size_func=None): + """A transformation that groups windows of elements by key and reduces them. + + This transformation maps each consecutive element in a dataset to a key + using `key_func` and groups the elements by key. It then applies + `reduce_func` to at most `window_size_func(key)` elements matching the same + key. All except the final window for each key will contain + `window_size_func(key)` elements; the final window may be smaller. + + You may provide either a constant `window_size` or a window size determined by + the key through `window_size_func`. + + Args: + key_func: A function mapping a nested structure of tensors + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to a scalar `tf.int64` tensor. + reduce_func: A function mapping a key and a dataset of up to `window_size` + consecutive elements matching that key to another dataset. + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements matching the same key to combine in a single + batch, which will be passed to `reduce_func`. Mutually exclusive with + `window_size_func`. + window_size_func: A function mapping a key to a `tf.int64` scalar + `tf.Tensor`, representing the number of consecutive elements matching + the same key to combine in a single batch, which will be passed to + `reduce_func`. Mutually exclusive with `window_size`. 
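+
+  An illustrative sketch (the input values and window size are assumptions for
+  this example): group the integers 0..9 by parity and emit batches of up to
+  three elements per key:
+
+  ```python
+  dataset = tf.data.Dataset.range(10)
+  dataset = dataset.apply(
+      tf.data.experimental.group_by_window(
+          key_func=lambda x: x % 2,  # 0 for even elements, 1 for odd elements.
+          reduce_func=lambda key, window: window.batch(3),
+          window_size=3))
+  ```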
+ + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + + Raises: + ValueError: if neither or both of {`window_size`, `window_size_func`} are + passed. + """ + if (window_size is not None and window_size_func or + not (window_size is not None or window_size_func)): + raise ValueError("Must pass either window_size or window_size_func.") + + if window_size is not None: + + def constant_window_func(unused_key): + return ops.convert_to_tensor(window_size, dtype=dtypes.int64) + + window_size_func = constant_window_func + + assert window_size_func is not None + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _GroupByWindowDataset(dataset, key_func, reduce_func, + window_size_func) + + return _apply_fn + + +@tf_export("data.experimental.bucket_by_sequence_length") +def bucket_by_sequence_length(element_length_func, + bucket_boundaries, + bucket_batch_sizes, + padded_shapes=None, + padding_values=None, + pad_to_bucket_boundary=False, + no_padding=False): + """A transformation that buckets elements in a `Dataset` by length. + + Elements of the `Dataset` are grouped together by length and then are padded + and batched. + + This is useful for sequence tasks in which the elements have variable length. + Grouping together elements that have similar lengths reduces the total + fraction of padding in a batch which increases training step efficiency. + + Args: + element_length_func: function from element in `Dataset` to `tf.int32`, + determines the length of the element, which will determine the bucket it + goes into. + bucket_boundaries: `list<int>`, upper length boundaries of the buckets. + bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be + `len(bucket_boundaries) + 1`. + padded_shapes: Nested structure of `tf.TensorShape` to pass to + `tf.data.Dataset.padded_batch`. If not provided, will use + `dataset.output_shapes`, which will result in variable length dimensions + being padded out to the maximum length in each batch. + padding_values: Values to pad with, passed to + `tf.data.Dataset.padded_batch`. Defaults to padding with 0. + pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown + size to maximum length in batch. If `True`, will pad dimensions with + unknown size to bucket boundary minus 1 (i.e., the maximum length in each + bucket), and caller must ensure that the source `Dataset` does not contain + any elements with length longer than `max(bucket_boundaries)`. + no_padding: `bool`, indicates whether to pad the batch features (features + need to be either of type `tf.SparseTensor` or of same shape). + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + + Raises: + ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. 
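+
+  A minimal sketch (the generator, boundaries, and batch sizes below are
+  assumptions for illustration): bucket variable-length integer vectors by
+  their length before batching:
+
+  ```python
+  def gen():
+    yield [1]
+    yield [2, 3]
+    yield [4, 5, 6, 7]
+
+  dataset = tf.data.Dataset.from_generator(
+      gen, tf.int32, output_shapes=tf.TensorShape([None]))
+  dataset = dataset.apply(
+      tf.data.experimental.bucket_by_sequence_length(
+          element_length_func=lambda elem: tf.shape(elem)[0],
+          bucket_boundaries=[2, 4],      # buckets: [0, 2), [2, 4), [4, inf)
+          bucket_batch_sizes=[2, 2, 2]))
+  ```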
+ """ + with ops.name_scope("bucket_by_seq_length"): + if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1): + raise ValueError( + "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1") + + batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) + + def element_to_bucket_id(*args): + """Return int64 id of the length bucket for this element.""" + seq_length = element_length_func(*args) + + boundaries = list(bucket_boundaries) + buckets_min = [np.iinfo(np.int32).min] + boundaries + buckets_max = boundaries + [np.iinfo(np.int32).max] + conditions_c = math_ops.logical_and( + math_ops.less_equal(buckets_min, seq_length), + math_ops.less(seq_length, buckets_max)) + bucket_id = math_ops.reduce_min(array_ops.where(conditions_c)) + + return bucket_id + + def window_size_fn(bucket_id): + # The window size is set to the batch size for this bucket + window_size = batch_sizes[bucket_id] + return window_size + + def make_padded_shapes(shapes, none_filler=None): + padded = [] + for shape in nest.flatten(shapes): + shape = tensor_shape.TensorShape(shape) + shape = [ + none_filler if d.value is None else d + for d in shape + ] + padded.append(shape) + return nest.pack_sequence_as(shapes, padded) + + def batching_fn(bucket_id, grouped_dataset): + """Batch elements in dataset.""" + batch_size = window_size_fn(bucket_id) + if no_padding: + return grouped_dataset.batch(batch_size) + none_filler = None + if pad_to_bucket_boundary: + err_msg = ("When pad_to_bucket_boundary=True, elements must have " + "length < max(bucket_boundaries).") + check = check_ops.assert_less( + bucket_id, + constant_op.constant(len(bucket_batch_sizes) - 1, + dtype=dtypes.int64), + message=err_msg) + with ops.control_dependencies([check]): + boundaries = constant_op.constant(bucket_boundaries, + dtype=dtypes.int64) + bucket_boundary = boundaries[bucket_id] + none_filler = bucket_boundary - 1 + shapes = make_padded_shapes( + padded_shapes or grouped_dataset.output_shapes, + none_filler=none_filler) + return grouped_dataset.padded_batch(batch_size, shapes, padding_values) + + def _apply_fn(dataset): + return dataset.apply( + group_by_window(element_to_bucket_id, batching_fn, + window_size_func=window_size_fn)) + + return _apply_fn + + +def _map_x_dataset(map_func): + """A transformation that maps `map_func` across its input. + + This transformation is similar to `tf.data.Dataset.map`, but in addition to + supporting dense and sparse tensor inputs, it also supports dataset inputs. + + Args: + map_func: A function mapping a nested structure of tensors and/or datasets + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to another nested structure of tensors and/or + datasets. + + Returns: + Dataset: A `Dataset`. 
+ """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _MapXDataset(dataset, map_func) + + return _apply_fn + + +class _GroupByReducerDataset(dataset_ops.UnaryDataset): + """A `Dataset` that groups its input and performs a reduction.""" + + def __init__(self, input_dataset, key_func, reducer): + """See `group_by_reducer()` for details.""" + super(_GroupByReducerDataset, self).__init__(input_dataset) + + self._input_dataset = input_dataset + + self._make_key_func(key_func, input_dataset) + self._make_init_func(reducer.init_func) + self._make_reduce_func(reducer.reduce_func, input_dataset) + self._make_finalize_func(reducer.finalize_func) + + def _make_key_func(self, key_func, input_dataset): + """Make wrapping Defun for key_func.""" + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func, "tf.data.experimental.group_by_reducer()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" + % (wrapped_func.output_types, wrapped_func.output_shapes)) + self._key_func = wrapped_func.function + + def _make_init_func(self, init_func): + """Make wrapping Defun for init_func.""" + wrapped_func = dataset_ops.StructuredFunctionWrapper( + init_func, + "tf.data.experimental.group_by_reducer()", + input_classes=ops.Tensor, + input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + self._init_func = wrapped_func.function + self._state_classes = wrapped_func.output_classes + self._state_shapes = wrapped_func.output_shapes + self._state_types = wrapped_func.output_types + + def _make_reduce_func(self, reduce_func, input_dataset): + """Make wrapping Defun for reduce_func.""" + + # Iteratively rerun the reduce function until reaching a fixed point on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func, + "tf.data.experimental.group_by_reducer()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(wrapped_func.output_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, wrapped_func.output_classes)) + + # Extract and validate type information from the returned values. + for new_state_type, state_type in zip( + nest.flatten(wrapped_func.output_types), + nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, wrapped_func.output_types)) + + # Extract shape information from the returned values. 
+ flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) + weakened_state_shapes = [ + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( + weakened_shape.ndims is None or + original_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._reduce_func = wrapped_func.function + self._reduce_func.add_to_graph(ops.get_default_graph()) + + def _make_finalize_func(self, finalize_func): + """Make wrapping Defun for finalize_func.""" + wrapped_func = dataset_ops.StructuredFunctionWrapper( + finalize_func, + "tf.data.experimental.group_by_reducer()", + input_classes=self._state_classes, + input_shapes=self._state_shapes, + input_types=self._state_types) + self._finalize_func = wrapped_func.function + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.group_by_reducer_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._key_func.captured_inputs, + self._init_func.captured_inputs, + self._reduce_func.captured_inputs, + self._finalize_func.captured_inputs, + key_func=self._key_func, + init_func=self._init_func, + reduce_func=self._reduce_func, + finalize_func=self._finalize_func, + **dataset_ops.flat_structure(self)) + + +class _GroupByWindowDataset(dataset_ops.UnaryDataset): + """A `Dataset` that groups its input and performs a windowed reduction.""" + + def __init__(self, input_dataset, key_func, reduce_func, window_size_func): + """See `group_by_window()` for details.""" + super(_GroupByWindowDataset, self).__init__(input_dataset) + + self._input_dataset = input_dataset + + self._make_key_func(key_func, input_dataset) + self._make_reduce_func(reduce_func, input_dataset) + self._make_window_size_func(window_size_func) + + def _make_window_size_func(self, window_size_func): + """Make wrapping Defun for window_size_func.""" + def window_size_func_wrapper(key): + return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + window_size_func_wrapper, + "tf.data.experimental.group_by_window()", + input_classes=ops.Tensor, + input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`window_size_func` must return a single tf.int64 scalar tensor.") + self._window_size_func = wrapped_func.function + + def _make_key_func(self, key_func, input_dataset): + """Make wrapping Defun for key_func.""" + def key_func_wrapper(*args): + return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func_wrapper, "tf.data.experimental.group_by_window()", + input_dataset) + if not ( + 
wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 scalar tensor.") + self._key_func = wrapped_func.function + + def _make_reduce_func(self, reduce_func, input_dataset): + """Make wrapping Defun for reduce_func.""" + nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func, + "tf.data.experimental.reduce_by_window()", + input_classes=(ops.Tensor, nested_dataset), + input_shapes=(tensor_shape.scalar(), nested_dataset), + input_types=(dtypes.int64, nested_dataset), + experimental_nested_dataset_support=True) + if not isinstance( + wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access + raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_classes = wrapped_func.output_classes.output_classes + self._output_types = wrapped_func.output_types.output_types + self._output_shapes = wrapped_func.output_shapes.output_shapes + self._reduce_func = wrapped_func.function + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.group_by_window_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._key_func.captured_inputs, + self._reduce_func.captured_inputs, + self._window_size_func.captured_inputs, + key_func=self._key_func, + reduce_func=self._reduce_func, + window_size_func=self._window_size_func, + **dataset_ops.flat_structure(self)) + + +@tf_export("data.experimental.Reducer") +class Reducer(object): + """A reducer is used for reducing a set of elements. 
+ + A reducer is represented as a tuple of the three functions: + 1) initialization function: key => initial state + 2) reduce function: (old state, input) => new state + 3) finalization function: state => result + """ + + def __init__(self, init_func, reduce_func, finalize_func): + self._init_func = init_func + self._reduce_func = reduce_func + self._finalize_func = finalize_func + + @property + def init_func(self): + return self._init_func + + @property + def reduce_func(self): + return self._reduce_func + + @property + def finalize_func(self): + return self._finalize_func + + +class _MapXDataset(dataset_ops.UnaryDataset): + """A `Dataset` that maps a function over elements in its input.""" + + def __init__(self, input_dataset, map_func): + """See `map_x_dataset()` for details.""" + super(_MapXDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + + wrapped_func = dataset_ops.StructuredFunctionWrapper( + map_func, + "tf.data.experimental.map_x_dataset()", + input_dataset, + experimental_nested_dataset_support=True) + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types + self._map_func = wrapped_func.function + + def _as_variant_tensor(self): + input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + return gen_dataset_ops.map_dataset( + input_t, + self._map_func.captured_inputs, + f=self._map_func, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types diff --git a/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py new file mode 100644 index 0000000000..9c06474a2f --- /dev/null +++ b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py @@ -0,0 +1,177 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrappers for indexed datasets.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops + + +class MaterializedIndexedDataset(object): + """MaterializedIndexedDataset is highly experimental! 
+ """ + + def __init__(self, materialized_resource, materializer, output_classes, + output_types, output_shapes): + self._materialized_resource = materialized_resource + self._materializer = materializer + self._output_classes = output_classes + self._output_types = output_types + self._output_shapes = output_shapes + + @property + def initializer(self): + if self._materializer is not None: + return self._materializer + raise ValueError("MaterializedDataset does not have a materializer") + + def get(self, index): + """Get retrieves a value (or set of values) from the IndexedDataset. + + Args: + index: A uint64 scalar or vector tensor with the indices to retrieve. + + Returns: + A tensor containing the values corresponding to `index`. + """ + # TODO(saeta): nest.pack_sequence_as(...) + return ged_ops.experimental_indexed_dataset_get( + self._materialized_resource, + index, + output_types=nest.flatten( + sparse.as_dense_types(self._output_types, self._output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_types(self._output_shapes, self._output_classes))) + + +class IndexedDataset(dataset_ops.Dataset): + """IndexedDataset is highly experimental! + """ + + def __init__(self): + pass + + def materialize(self, shared_name=None, container=None): + """Materialize creates a MaterializedIndexedDataset. + + IndexedDatasets can be combined through operations such as TBD. Therefore, + they are only materialized when absolutely required. + + Args: + shared_name: a string for the shared name to use for the resource. + container: a string for the container to store the resource. + + Returns: + A MaterializedIndexedDataset. + """ + if container is None: + container = "" + if shared_name is None: + shared_name = "" + materialized_resource = ( + ged_ops.experimental_materialized_index_dataset_handle( + container=container, + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_types(self.output_shapes, + self.output_classes)))) + + with ops.colocate_with(materialized_resource): + materializer = ged_ops.experimental_indexed_dataset_materialize( + self._as_variant_tensor(), materialized_resource) + return MaterializedIndexedDataset(materialized_resource, materializer, + self.output_classes, self.output_types, + self.output_shapes) + + @abc.abstractproperty + def output_types(self): + """Returns the type of each component of an element of this IndexedDataset. + + Returns: + A nested structure of `tf.DType` objects corresponding to each component + of an element of this IndexedDataset. + """ + raise NotImplementedError("IndexedDataset.output_types") + + @abc.abstractproperty + def output_classes(self): + """Returns the class of each component of an element of this IndexedDataset. + + The expected values are `tf.Tensor` and `tf.SparseTensor`. + + Returns: + A nested structure of Python `type` objects corresponding to each + component of an element of this IndexedDataset. + """ + raise NotImplementedError("IndexedDataset.output_classes") + + @abc.abstractproperty + def output_shapes(self): + """Returns the shape of each component of an element of this IndexedDataset. + + Returns: + A nested structure of `tf.TensorShape` objects corresponding to each + component of an element of this IndexedDataset. 
+ """ + raise NotImplementedError("IndexedDataset.output_shapes") + + @abc.abstractmethod + def _as_variant_tensor(self): + """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset. + + Returns: + A scalar `tf.Tensor` of `tf.variant` type, which represents this + IndexedDataset. + """ + raise NotImplementedError("IndexedDataset._as_variant_tensor") + + +class IdentityIndexedDataset(IndexedDataset): + """IdentityIndexedDataset is a trivial indexed dataset used for testing. + """ + + def __init__(self, size): + super(IdentityIndexedDataset, self).__init__() + # TODO(saeta): Verify _size is a scalar! + self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size") + + @property + def output_types(self): + return dtypes.uint64 + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + def _as_variant_tensor(self): + return ged_ops.experimental_identity_indexed_dataset(self._size) + + def _inputs(self): + return [] diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py new file mode 100644 index 0000000000..a3c094859e --- /dev/null +++ b/tensorflow/python/data/experimental/ops/interleave_ops.py @@ -0,0 +1,262 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Non-deterministic dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.ops import random_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_experimental_dataset_ops +from tensorflow.python.ops import gen_stateless_random_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.parallel_interleave") +def parallel_interleave(map_func, + cycle_length, + block_length=1, + sloppy=False, + buffer_output_elements=None, + prefetch_input_elements=None): + """A parallel version of the `Dataset.interleave()` transformation. + + `parallel_interleave()` maps `map_func` across its input to produce nested + datasets, and outputs their elements interleaved. Unlike + `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested + datasets in parallel, which increases the throughput, especially in the + presence of stragglers. 
Furthermore, the `sloppy` argument can be used to + improve performance, by relaxing the requirement that the outputs are produced + in a deterministic order, and allowing the implementation to skip over nested + datasets whose elements are not readily available when requested. + + Example usage: + + ```python + # Preprocess 4 files concurrently. + filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords") + dataset = filenames.apply( + tf.data.experimental.parallel_interleave( + lambda filename: tf.data.TFRecordDataset(filename), + cycle_length=4)) + ``` + + WARNING: If `sloppy` is `True`, the order of produced elements is not + deterministic. + + Args: + map_func: A function mapping a nested structure of tensors to a `Dataset`. + cycle_length: The number of input `Dataset`s to interleave from in parallel. + block_length: The number of consecutive elements to pull from an input + `Dataset` before advancing to the next input `Dataset`. + sloppy: If false, elements are produced in deterministic order. Otherwise, + the implementation is allowed, for the sake of expediency, to produce + elements in a non-deterministic order. + buffer_output_elements: The number of elements each iterator being + interleaved should buffer (similar to the `.prefetch()` transformation for + each interleaved iterator). + prefetch_input_elements: The number of input elements to transform to + iterators before they are needed for interleaving. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + def _apply_fn(dataset): + return readers.ParallelInterleaveDataset( + dataset, map_func, cycle_length, block_length, sloppy, + buffer_output_elements, prefetch_input_elements) + + return _apply_fn + + +class _DirectedInterleaveDataset(dataset_ops.Dataset): + """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" + + def __init__(self, selector_input, data_inputs): + self._selector_input = selector_input + self._data_inputs = list(data_inputs) + + for data_input in data_inputs[1:]: + if (data_input.output_types != data_inputs[0].output_types or + data_input.output_classes != data_inputs[0].output_classes): + raise TypeError("All datasets must have the same type and class.") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return ( + gen_experimental_dataset_ops.experimental_directed_interleave_dataset( + self._selector_input._as_variant_tensor(), [ + data_input._as_variant_tensor() + for data_input in self._data_inputs + ], **dataset_ops.flat_structure(self))) + # pylint: enable=protected-access + + def _inputs(self): + return [self._selector_input] + self._data_inputs + + @property + def output_classes(self): + return self._data_inputs[0].output_classes + + @property + def output_shapes(self): + ret = self._data_inputs[0].output_shapes + for data_input in self._data_inputs[1:]: + ret = nest.pack_sequence_as(ret, [ + ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip( + nest.flatten(ret), nest.flatten(data_input.output_shapes)) + ]) + return ret + + @property + def output_types(self): + return self._data_inputs[0].output_types + + +@tf_export("data.experimental.sample_from_datasets") +def sample_from_datasets(datasets, weights=None, seed=None): + """Samples elements at random from the datasets in `datasets`. + + Args: + datasets: A list of `tf.data.Dataset` objects with compatible structure. + weights: (Optional.) 
A list of `len(datasets)` floating-point values where + `weights[i]` represents the probability with which an element should be + sampled from `datasets[i]`, or a `tf.data.Dataset` object where each + element is such a list. Defaults to a uniform distribution across + `datasets`. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + `tf.set_random_seed` for behavior. + + Returns: + A dataset that interleaves elements from `datasets` at random, according to + `weights` if provided, otherwise with uniform probability. + + Raises: + TypeError: If the `datasets` or `weights` arguments have the wrong type. + ValueError: If the `weights` argument is specified and does not match the + length of the `datasets` element. + """ + num_datasets = len(datasets) + if not isinstance(weights, dataset_ops.Dataset): + if weights is None: + # Select inputs with uniform probability. + logits = [[1.0] * num_datasets] + + else: + # Use the given `weights` as the probability of choosing the respective + # input. + weights = ops.convert_to_tensor(weights, name="weights") + if weights.dtype not in (dtypes.float32, dtypes.float64): + raise TypeError("`weights` must be convertible to a tensor of " + "`tf.float32` or `tf.float64` elements.") + if not weights.shape.is_compatible_with([num_datasets]): + raise ValueError( + "`weights` must be a vector of length `len(datasets)`.") + + # The `stateless_multinomial()` op expects log-probabilities, as opposed + # to weights. + logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0) + + # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it + # is a `Dataset`, it is possible that evaluating it has a side effect the + # user depends on. + if len(datasets) == 1: + return datasets[0] + + def select_dataset_constant_logits(seed): + return array_ops.squeeze( + gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed), + axis=[0, 1]) + + selector_input = dataset_ops.MapDataset( + random_ops.RandomDataset(seed).batch(2), + select_dataset_constant_logits, + use_inter_op_parallelism=False) + + else: + # Use each element of the given `weights` dataset as the probability of + # choosing the respective input. + + # The `stateless_multinomial()` op expects log-probabilities, as opposed to + # weights. + logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) + + def select_dataset_varying_logits(logits, seed): + return array_ops.squeeze( + gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed), + axis=[0, 1]) + + logits_and_seeds = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2))) + selector_input = dataset_ops.MapDataset( + logits_and_seeds, + select_dataset_varying_logits, + use_inter_op_parallelism=False) + + return _DirectedInterleaveDataset(selector_input, datasets) + + +@tf_export("data.experimental.choose_from_datasets") +def choose_from_datasets(datasets, choice_dataset): + """Creates a dataset that deterministically chooses elements from `datasets`. + + For example, given the following datasets: + + ```python + datasets = [tf.data.Dataset.from_tensors("foo").repeat(), + tf.data.Dataset.from_tensors("bar").repeat(), + tf.data.Dataset.from_tensors("baz").repeat()] + + # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. 
+ choice_dataset = tf.data.Dataset.range(3).repeat(3) + + result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset) + ``` + + The elements of `result` will be: + + ``` + "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" + ``` + + Args: + datasets: A list of `tf.data.Dataset` objects with compatible structure. + choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between + `0` and `len(datasets) - 1`. + + Returns: + A dataset that interleaves elements from `datasets` according to the values + of `choice_dataset`. + + Raises: + TypeError: If the `datasets` or `choice_dataset` arguments have the wrong + type. + """ + if not (choice_dataset.output_types == dtypes.int64 + and choice_dataset.output_shapes.is_compatible_with( + tensor_shape.scalar()) + and choice_dataset.output_classes == ops.Tensor): + raise TypeError("`choice_dataset` must be a dataset of scalar " + "`tf.int64` tensors.") + return _DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/python/data/experimental/ops/iterator_ops.py b/tensorflow/python/data/experimental/ops/iterator_ops.py new file mode 100644 index 0000000000..72d7d58f06 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/iterator_ops.py @@ -0,0 +1,268 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Iterator ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import optional_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import session_run_hook +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.make_saveable_from_iterator") +def make_saveable_from_iterator(iterator): + """Returns a SaveableObject for saving/restore iterator state using Saver. + + Args: + iterator: Iterator. + + For example: + + ```python + with tf.Graph().as_default(): + ds = tf.data.Dataset.range(10) + iterator = ds.make_initializable_iterator() + # Build the iterator SaveableObject. + saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator) + # Add the SaveableObject to the SAVEABLE_OBJECTS collection so + # it can be automatically saved using Saver. + tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = tf.train.Saver() + + while continue_training: + ... Perform training ... + if should_save_checkpoint: + saver.save() + ``` + + Note: When restoring the iterator, the existing iterator state is completely + discarded. 
This means that any changes you may have made to the Dataset + graph will be discarded as well! This includes the new Dataset graph + that you may have built during validation. So, while running validation, + make sure to run the initializer for the validation input pipeline after + restoring the checkpoint. + + Note: Not all iterators support checkpointing yet. Attempting to save the + state of an unsupported iterator will throw an error. + """ + return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access + + +class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" + + def __init__(self, iterator_resource): + serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) + specs = [ + saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "", + iterator_resource.name + "-state") + ] + super(_Saveable, self).__init__(iterator_resource, specs, + iterator_resource.name) + + def restore(self, restored_tensors, unused_restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) + + +@tf_export("data.experimental.CheckpointInputPipelineHook") +class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): + """Checkpoints input pipeline state every N steps or seconds. + + This hook saves the state of the iterators in the `Graph` so that when + training is resumed the input pipeline continues from where it left off. + This could potentially avoid overfitting in certain pipelines where the + number of training steps per eval are small compared to the dataset + size or if the training pipeline is pre-empted. + + Differences from `CheckpointSaverHook`: + 1. Saves only the input pipelines in the "iterators" collection and not the + global variables or other saveable objects. + 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary. + + Example of checkpointing the training pipeline: + + ```python + est = tf.estimator.Estimator(model_fn) + while True: + est.train( + train_input_fn, + hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)], + steps=train_steps_per_eval) + # Note: We do not pass the hook here. + metrics = est.evaluate(eval_input_fn) + if should_stop_the_training(metrics): + break + ``` + + This hook should be used if the input pipeline state needs to be saved + separate from the model checkpoint. Doing so may be useful for a few reasons: + 1. The input pipeline checkpoint may be large, if there are large shuffle + or prefetch buffers for instance, and may bloat the checkpoint size. + 2. If the input pipeline is shared between training and validation, restoring + the checkpoint during validation may override the validation input + pipeline. + + For saving the input pipeline checkpoint alongside the model weights use + `tf.data.experimental.make_saveable_from_iterator` directly to create a + `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however, + that you will need to be careful not to restore the training iterator during + eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS + collector when building the eval graph. + """ + + def __init__(self, estimator): + """Initializes a `CheckpointInputPipelineHook`. + + Args: + estimator: Estimator. + + Raises: + ValueError: One of `save_steps` or `save_secs` should be set. + ValueError: At most one of saver or scaffold should be set. 
+ """ + # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or + # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines. + # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is + # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix + # to be different to avoid conflicts with the model checkpoint. + + # pylint: disable=protected-access + checkpoint_prefix = "input" + if estimator._config.num_worker_replicas > 1: + # Distributed setting. + suffix = "_{}_{}".format(estimator._config.task_type, + estimator._config.task_id) + checkpoint_prefix += suffix + # pylint: enable=protected-access + + # We use a composition paradigm instead of inheriting from + # `CheckpointSaverHook` because `Estimator` does an `isinstance` check + # to check whether a `CheckpointSaverHook` is already present in the list + # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook` + # would thwart this behavior. This hook checkpoints *only the iterators* + # and not the graph variables. + self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook( + estimator.model_dir, + save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access + save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access + checkpoint_basename=checkpoint_prefix + ".ckpt") + + # Name for the protocol buffer file that will contain the list of most + # recent checkpoints stored as a `CheckpointState` protocol buffer. + # This file, kept in the same directory as the checkpoint files, is + # automatically managed by the `Saver` to keep track of recent checkpoints. + # The default name used by the `Saver` for this file is "checkpoint". Here + # we use the name "checkpoint_<checkpoint_prefix>" so that in case the + # `checkpoint_dir` is the same as the model checkpoint directory, there are + # no conflicts during restore. + self._latest_filename = "checkpoint_" + checkpoint_prefix + self._first_run = True + + def begin(self): + # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` + # collection if no `Saver` or `Scaffold` is provided. + # pylint: disable=protected-access + if (self._checkpoint_saver_hook._saver is None and + self._checkpoint_saver_hook._scaffold is None): + iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) + saveables = [_Saveable(i) for i in iterators] + self._checkpoint_saver_hook._saver = _CustomSaver(saveables, + self._latest_filename) + # pylint: enable=protected-access + self._checkpoint_saver_hook.begin() + + def _restore_or_save_initial_ckpt(self, session): + # Ideally this should be run in after_create_session but is not for the + # following reason: + # Currently there is no way of enforcing an order of running the + # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` + # is run *after* this hook. That is troublesome because + # 1. If a checkpoint exists and this hook restores it, the initializer hook + # will override it. + # 2. If no checkpoint exists, this hook will try to save an initialized + # iterator which will result in an exception. + # + # As a temporary fix we enter the following implicit contract between this + # hook and the _DatasetInitializerHook. + # 1. The _DatasetInitializerHook initializes the iterator in the call to + # after_create_session. + # 2. 
This hook saves the iterator on the first call to `before_run()`, which + # is guaranteed to happen after `after_create_session()` of all hooks + # have been run. + + # Check if there is an existing checkpoint. If so, restore from it. + # pylint: disable=protected-access + latest_checkpoint_path = checkpoint_management.latest_checkpoint( + self._checkpoint_saver_hook._checkpoint_dir, + latest_filename=self._latest_filename) + if latest_checkpoint_path: + self._checkpoint_saver_hook._get_saver().restore(session, + latest_checkpoint_path) + else: + # The checkpoint saved here is the state at step "global_step". + # Note: We do not save the GraphDef or MetaGraphDef here. + global_step = session.run(self._checkpoint_saver_hook._global_step_tensor) + self._checkpoint_saver_hook._save(session, global_step) + self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step) + # pylint: enable=protected-access + + def before_run(self, run_context): + if self._first_run: + self._restore_or_save_initial_ckpt(run_context.session) + self._first_run = False + return self._checkpoint_saver_hook.before_run(run_context) + + def after_run(self, run_context, run_values): + self._checkpoint_saver_hook.after_run(run_context, run_values) + + def end(self, session): + self._checkpoint_saver_hook.end(session) + + +class _CustomSaver(saver_lib.Saver): + """`Saver` with a different default `latest_filename`. + + This is used in the `CheckpointInputPipelineHook` to avoid conflicts with + the model ckpt saved by the `CheckpointSaverHook`. + """ + + def __init__(self, var_list, latest_filename): + super(_CustomSaver, self).__init__(var_list) + self._latest_filename = latest_filename + + def save(self, + sess, + save_path, + global_step=None, + latest_filename=None, + meta_graph_suffix="meta", + write_meta_graph=True, + write_state=True, + strip_default_attrs=False): + return super(_CustomSaver, self).save( + sess, save_path, global_step, latest_filename or self._latest_filename, + meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) + + +tf_export("data.experimental.Optional")(optional_ops.Optional) +tf_export("data.experimental.get_next_as_optional")( + iterator_ops.get_next_as_optional) diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py new file mode 100644 index 0000000000..3d0d0993c9 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/map_defun.py @@ -0,0 +1,56 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +def map_defun(fn, elems, output_dtypes, output_shapes): + """Map a function on the list of tensors unpacked from `elems` on dimension 0. + + Args: + fn: A function (`function.Defun`) that takes a list of tensors and returns + another list of tensors. The output list has the same types as + output_dtypes. The elements of the output list have the same dimension 0 + as `elems`, and the remaining dimensions correspond to those of + `fn_output_shapes`. + elems: A list of tensors. + output_dtypes: A list of dtypes corresponding to the output types of the + function. + output_shapes: A list of `TensorShape`s corresponding to the output + shapes from each invocation of the function on slices of inputs. + + Raises: + ValueError: if any of the inputs are malformed. + + Returns: + A list of `Tensor` objects with the same types as `output_dtypes`. + """ + if not isinstance(elems, list): + raise ValueError("`elems` must be a list of tensors.") + if not isinstance(output_dtypes, list): + raise ValueError("`output_dtypes` must be a list of tensors.") + if not isinstance(output_shapes, list): + raise ValueError("`output_shapes` must be a list of tensors.") + + elems = [ops.convert_to_tensor(e) for e in elems] + output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] + return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py new file mode 100644 index 0000000000..30348ede36 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/optimization.py @@ -0,0 +1,171 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops + +# A constant that can be used to enable auto-tuning. +AUTOTUNE = -1 + + +# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to +# account for indexing) and transformation sequence. +def assert_next(transformations): + """A transformation that asserts which transformations happen next. 
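+
+  This is primarily useful for testing `tf.data` graph rewrites. A minimal
+  sketch of the intended usage (the transformation name below is illustrative,
+  and the assertion is only checked when the pipeline is executed):
+
+  ```python
+  dataset = (dataset_ops.Dataset.range(10)
+             .apply(assert_next(["Map"]))
+             .map(lambda x: x + 1))
+  ```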
+ + Args: + transformations: A `tf.string` vector `tf.Tensor` identifying the + transformations that are expected to happen next. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _AssertNextDataset(dataset, transformations) + + return _apply_fn + + +def model(): + """A transformation that models performance. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _ModelDataset(dataset) + + return _apply_fn + + +def optimize(optimizations=None): + """A transformation that applies optimizations. + + Args: + optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying + optimizations to use. If not specified, the default set of optimizations + is applied. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _OptimizeDataset(dataset, optimizations) + + return _apply_fn + + +class _AssertNextDataset(dataset_ops.UnaryDataset): + """A `Dataset` that asserts which transformations happen next.""" + + def __init__(self, input_dataset, transformations): + """See `assert_next()` for details.""" + super(_AssertNextDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + if transformations is None: + raise ValueError("At least one transformation should be specified") + self._transformations = ops.convert_to_tensor( + transformations, dtype=dtypes.string, name="transformations") + + def _as_variant_tensor(self): + return gen_experimental_dataset_ops.experimental_assert_next_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._transformations, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +class _ModelDataset(dataset_ops.UnaryDataset): + """A `Dataset` that acts as an identity, and models performance.""" + + def __init__(self, input_dataset): + """See `optimize()` for details.""" + super(_ModelDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + + def _as_variant_tensor(self): + return gen_dataset_ops.model_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +class _OptimizeDataset(dataset_ops.UnaryDataset): + """A `Dataset` that acts as an identity, and applies optimizations.""" + + def __init__(self, input_dataset, optimizations): + """See `optimize()` for details.""" + super(_OptimizeDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + if optimizations is None: + optimizations = [] + self._optimizations = ops.convert_to_tensor( + optimizations, dtype=dtypes.string, name="optimizations") + + def 
_as_variant_tensor(self): + return gen_dataset_ops.optimize_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._optimizations, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/python/data/experimental/ops/parsing_ops.py b/tensorflow/python/data/experimental/ops/parsing_ops.py new file mode 100644 index 0000000000..6615b9022a --- /dev/null +++ b/tensorflow/python/data/experimental/ops/parsing_ops.py @@ -0,0 +1,152 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental `dataset` API for parsing example.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.util.tf_export import tf_export + + +class _ParseExampleDataset(dataset_ops.UnaryDataset): + """A `Dataset` that parses `example` dataset into a `dict` dataset.""" + + def __init__(self, input_dataset, features, num_parallel_calls): + super(_ParseExampleDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + if not all(types == dtypes.string + for types in nest.flatten(input_dataset.output_types)): + raise TypeError("Input dataset should be a dataset of vectors of strings") + self._num_parallel_calls = num_parallel_calls + # pylint: disable=protected-access + self._features = parsing_ops._prepend_none_dimension(features) + # sparse_keys and dense_keys come back sorted here. + (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults, + dense_shapes) = parsing_ops._features_to_raw_params( + self._features, [ + parsing_ops.VarLenFeature, parsing_ops.SparseFeature, + parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature + ]) + # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature. 
+ (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes, + dense_shape_as_shape) = parsing_ops._process_raw_parameters( + None, dense_defaults, sparse_keys, sparse_types, dense_keys, + dense_types, dense_shapes) + # pylint: enable=protected-access + self._sparse_keys = sparse_keys + self._sparse_types = sparse_types + self._dense_keys = dense_keys + self._dense_defaults = dense_defaults_vec + self._dense_shapes = dense_shapes + self._dense_types = dense_types + dense_output_shapes = [ + self._input_dataset.output_shapes.concatenate(shape) + for shape in dense_shape_as_shape + ] + sparse_output_shapes = [ + self._input_dataset.output_shapes.concatenate([None]) + for _ in range(len(sparse_keys)) + ] + + self._output_shapes = dict( + zip(self._dense_keys + self._sparse_keys, + dense_output_shapes + sparse_output_shapes)) + self._output_types = dict( + zip(self._dense_keys + self._sparse_keys, + self._dense_types + self._sparse_types)) + self._output_classes = dict( + zip(self._dense_keys + self._sparse_keys, + [ops.Tensor for _ in range(len(self._dense_defaults))] + + [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys)) + ])) + + def _as_variant_tensor(self): + return gen_dataset_ops.parse_example_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._num_parallel_calls, + self._dense_defaults, + self._sparse_keys, + self._dense_keys, + self._sparse_types, + self._dense_shapes, + **dataset_ops.flat_structure(self)) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + @property + def output_classes(self): + return self._output_classes + + +# TODO(b/111553342): add arguments names and example names as well. +@tf_export("data.experimental.parse_example_dataset") +def parse_example_dataset(features, num_parallel_calls=1): + """A transformation that parses `Example` protos into a `dict` of tensors. + + Parses a number of serialized `Example` protos given in `serialized`. We refer + to `serialized` as a batch with `batch_size` many entries of individual + `Example` protos. + + This op parses serialized examples into a dictionary mapping keys to `Tensor` + and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`, + `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature` + and `SparseFeature` is mapped to a `SparseTensor`, and each + `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more + details about feature dictionaries. + + Args: + features: A `dict` mapping feature keys to `FixedLenFeature`, + `VarLenFeature`, and `SparseFeature` values. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of parsing processes to call in parallel. + + Returns: + A dataset transformation function, which can be passed to + `tf.data.Dataset.apply`. + + Raises: + ValueError: if features argument is None. + """ + if features is None: + raise ValueError("Missing: features was %s." 
% features) + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls) + if any([ + isinstance(feature, parsing_ops.SparseFeature) + for _, feature in features.items() + ]): + # pylint: disable=protected-access + # pylint: disable=g-long-lambda + out_dataset = out_dataset.map( + lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features( + features, x), num_parallel_calls=num_parallel_calls) + return out_dataset + + return _apply_fn diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py new file mode 100644 index 0000000000..48d7136f95 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py @@ -0,0 +1,531 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrapper for prefetching_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context +from tensorflow.python.framework import device as framework_device +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.util.tf_export import tf_export + + +def function_buffering_resource(string_arg, + target_device, + f, + buffer_size, + output_types, + container="", + shared_name=None, + name=None): + """Creates a FunctionBufferingResource. + + A FunctionBufferingResource fills up a buffer by calling a function `f` on + `target_device`. `f` should take in only a single string argument as input. + + Args: + string_arg: The single string argument to the function. + target_device: The device to run `f` on. + f: The function to be executed. + buffer_size: Size of the buffer to be populated. + output_types: The output types generated by the function. + container: (Optional) string. Defaults to "". + shared_name: (Optional) string. + name: (Optional) string to name the op. + + Returns: + Handle to a FunctionBufferingResource. 
+ """ + if shared_name is None: + shared_name = "" + return ged_ops.experimental_function_buffering_resource( + string_arg=string_arg, + target_device=target_device, + shared_name=shared_name, + f=f, + buffer_size=buffer_size, + container=container, + name=name, + output_types=output_types) + + +def function_buffering_resource_get_next(function_buffer_resource, + output_types, + name=None): + return ged_ops.experimental_function_buffering_resource_get_next( + function_buffer_resource=function_buffer_resource, + output_types=output_types, + name=name) + + +def function_buffering_resource_reset(function_buffer_resource, name=None): + return ged_ops.experimental_function_buffering_resource_reset( + function_buffer_resource=function_buffer_resource, name=name) + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for `tf.data.Iterator` that prefetches to another device. + + Args: + input_dataset: The input dataset + one_shot: If true, we make a one shot iterator that's already initialized. + device: A fully specified device string where we want to prefetch to + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + one_shot, + device, + buffer_size, + shared_name=None): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._one_shot = one_shot + if shared_name is None: + shared_name = "" + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + iterator_device = ged_ops.experimental_iterator_get_device( + self._input_iterator._iterator_resource) + + with ops.device(device): + self._buffering_resource = function_buffering_resource( + f=_prefetch_fn, + target_device=iterator_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes))) + + if not self._one_shot: + reset_op = function_buffering_resource_reset(self._buffering_resource) + with ops.control_dependencies([reset_op]): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) + + def get_next(self, name=None): + """See `tf.data.Iterator.get_next`.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_ret = ged_ops.experimental_function_buffering_resource_get_next( + self._buffering_resource, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + name=name) + + ret = 
sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + nest.flatten(ret), nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + + return ret + + @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): + """A replacement for `tf.data.Iterator` that prefetches to another device. + + Args: + input_dataset: The input dataset + one_shot: If true, we make a one shot iterator that's already initialized. + device: A fully specified device string where we want to prefetch to + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + device, + buffer_size): + with ops.device("/device:CPU:0"): + super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset) + input_iterator_handle = gen_dataset_ops.iterator_to_string_handle( + self._resource) + + self._device = device + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self.output_types, self.output_shapes, self.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + _prefetch_fn.add_to_graph(None) + + with ops.device(device): + self._buffering_resource = function_buffering_resource( + f=_prefetch_fn, + output_types=self._flat_output_types, + target_device=ged_ops.experimental_iterator_get_device( + self._resource), + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=iterator_ops._generate_shared_name( + "function_buffer_resource")) + + def _next_internal(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ + # This runs in sync mode as iterators use an error status to communicate + # that there is no more data to iterate over. 
+ # TODO(b/77291417): Fix + with context.execution_mode(context.SYNC): + with ops.device(self._device): + ret = ged_ops.experimental_function_buffering_resource_get_next( + function_buffer_resource=self._buffering_resource, + output_types=self._flat_output_types) + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): + """A `Dataset` whose iterator prefetches elements to another device.""" + + def __init__(self, input_dataset, device, buffer_size): + super(_PrefetchToDeviceDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._device = device + self._buffer_size = buffer_size if buffer_size is not None else 1 + + # The static analysis cannot tell that the eager iterator's superclass has + # a `next()` method. + # pylint: disable=non-iterator-returned + def __iter__(self): + """Creates an `Iterator` for enumerating the elements of this dataset. + + The returned iterator implements the Python iterator protocol and therefore + can only be used in eager mode. + + Returns: + An `Iterator` over the elements of this dataset. + + Raises: + RuntimeError: If eager execution is enabled. + """ + if context.executing_eagerly(): + return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, + self._buffer_size) + else: + raise RuntimeError("dataset.__iter__() is only supported when eager " + "execution is enabled.") + # pylint: enable=non-iterator-returned + + def make_one_shot_iterator(self): + if context.executing_eagerly(): + return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, + self._buffer_size) + else: + return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True, + device=self._device, + buffer_size=self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + device=self._device, + buffer_size=self._buffer_size, + shared_name=shared_name) + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_device()` must be the last " + "transformation in a dataset pipeline.") + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +@tf_export("data.experimental.prefetch_to_device") +def prefetch_to_device(device, buffer_size=None): + """A transformation that prefetches dataset values to the given `device`. + + NOTE: Although the transformation creates a `tf.data.Dataset`, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + device: A string. The name of a device to which elements will be prefetched. + buffer_size: (Optional.) The number of elements to buffer on `device`. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. 
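+
+  A minimal sketch of the intended usage (this assumes a GPU named
+  "/device:GPU:0" is available; any valid device string may be used):
+
+  ```python
+  dataset = tf.data.Dataset.range(1000).batch(100)
+  dataset = dataset.apply(
+      tf.data.experimental.prefetch_to_device("/device:GPU:0"))
+  ```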
+ """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, device, buffer_size) + + return _apply_fn + + +@tf_export("data.experimental.copy_to_device") +def copy_to_device(target_device, source_device="/cpu:0"): + """A transformation that copies dataset elements to the given `target_device`. + + Args: + target_device: The name of a device to which elements will be copied. + source_device: The original device on which `input_dataset` will be placed. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + return _CopyToDeviceDataset( + dataset, target_device=target_device, source_device=source_device) + + return _apply_fn + + +# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate +# all inputs to the Op are in host memory, thereby avoiding some unnecessary +# Sends and Recvs. +class _CopyToDeviceDataset(dataset_ops.UnaryDataset): + """A `Dataset` that copies elements to another device.""" + + def __init__(self, input_dataset, target_device, source_device="/cpu:0"): + """Constructs a _CopyToDeviceDataset. + + Args: + input_dataset: `Dataset` to be copied + target_device: The name of the device to which elements would be copied. + source_device: Device where input_dataset would be placed. + """ + super(_CopyToDeviceDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._target_device = target_device + spec = framework_device.DeviceSpec().from_string(self._target_device) + self._is_gpu_target = (spec.device_type == "GPU") + self._source_device_string = source_device + self._source_device = ops.convert_to_tensor(source_device) + + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._input_dataset.output_shapes, + self._input_dataset.output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes)) + + @function.Defun() + def _init_func(): + """Creates an iterator for the input dataset. + + Returns: + A `string` tensor that encapsulates the iterator created. + """ + # pylint: disable=protected-access + ds_variant = self._input_dataset._as_variant_tensor() + resource = gen_dataset_ops.anonymous_iterator( + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + with ops.control_dependencies( + [gen_dataset_ops.make_iterator(ds_variant, resource)]): + return gen_dataset_ops.iterator_to_string_handle(resource) + + @function.Defun() + def _remote_init_func(): + return functional_ops.remote_call( + target=self._source_device, + args=_init_func.captured_inputs, + Tout=[dtypes.string], + f=_init_func) + + self._init_func = _remote_init_func + self._init_captured_args = _remote_init_func.captured_inputs + + @function.Defun(dtypes.string) + def _next_func(string_handle): + """Calls get_next for created iterator. 
+ + Args: + string_handle: An iterator string handle created by _init_func + Returns: + The elements generated from `input_dataset` + """ + with ops.device(self._source_device_string): + iterator = iterator_ops.Iterator.from_string_handle( + string_handle, self.output_types, self.output_shapes, + self.output_classes) + ret = iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + @function.Defun(dtypes.string) + def _remote_next_func(string_handle): + return functional_ops.remote_call( + target=self._source_device, + args=[string_handle] + _next_func.captured_inputs, + Tout=self._flat_output_types, + f=_next_func) + + self._next_func = _remote_next_func + self._next_captured_args = _remote_next_func.captured_inputs + + @function.Defun(dtypes.string) + def _finalize_func(string_handle): + """Destroys the iterator resource created. + + Args: + string_handle: An iterator string handle created by _init_func + Returns: + Tensor constant 0 + """ + iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( + string_handle, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + with ops.control_dependencies([ + resource_variable_ops.destroy_resource_op( + iterator_resource, ignore_lookup_error=True)]): + return array_ops.constant(0, dtypes.int64) + + @function.Defun(dtypes.string) + def _remote_finalize_func(string_handle): + return functional_ops.remote_call( + target=self._source_device, + args=[string_handle] + _finalize_func.captured_inputs, + Tout=[dtypes.int64], + f=_finalize_func) + + self._finalize_func = _remote_finalize_func + self._finalize_captured_args = _remote_finalize_func.captured_inputs + + g = ops.get_default_graph() + _remote_init_func.add_to_graph(g) + _remote_next_func.add_to_graph(g) + _remote_finalize_func.add_to_graph(g) + # pylint: enable=protected-scope + + # The one_shot_iterator implementation needs a 0 arg _make_dataset function + # that thereby captures all the inputs required to create the dataset. Since + # there are strings that are inputs to the GeneratorDataset which can't be + # placed on a GPU, this fails for the GPU case. Therefore, disabling it for + # GPU + def make_one_shot_iterator(self): + if self._is_gpu_target: + raise ValueError("Cannot create a one shot iterator when using " + "`tf.data.experimental.copy_to_device()` on GPU. Please " + "use `Dataset.make_initializable_iterator()` instead.") + else: + return super(_CopyToDeviceDataset, self).make_one_shot_iterator() + + def _as_variant_tensor(self): + with ops.device(self._target_device): + return gen_dataset_ops.generator_dataset( + self._init_captured_args, + self._next_captured_args, + self._finalize_captured_args, + init_func=self._init_func, + next_func=self._next_func, + finalize_func=self._finalize_func, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py new file mode 100644 index 0000000000..e3a2aeab31 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/random_ops.py @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Datasets for random number generators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import random_seed +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.RandomDataset") +class RandomDataset(dataset_ops.DatasetSource): + """A `Dataset` of pseudorandom values.""" + + def __init__(self, seed=None): + """A `Dataset` of pseudorandom values.""" + super(RandomDataset, self).__init__() + self._seed, self._seed2 = random_seed.get_seed(seed) + + def _as_variant_tensor(self): + return gen_dataset_ops.random_dataset( + seed=self._seed, + seed2=self._seed2, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.int64 diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py new file mode 100644 index 0000000000..3b2d094514 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/readers.py @@ -0,0 +1,904 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Python wrappers for reader Datasets.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import csv + +import numpy as np + +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.experimental.ops import parsing_ops +from tensorflow.python.data.experimental.ops import shuffle_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import convert +from tensorflow.python.data.util import nest +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops +from tensorflow.python.platform import gfile +from tensorflow.python.util.tf_export import tf_export + +_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64, dtypes.string) + + +def _is_valid_int32(str_val): + try: + # Checks equality to prevent int32 overflow + return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype( + str_val) + except (ValueError, OverflowError): + return False + + +def _is_valid_int64(str_val): + try: + dtypes.int64.as_numpy_dtype(str_val) + return True + except (ValueError, OverflowError): + return False + + +def _is_valid_float(str_val, float_dtype): + try: + return float_dtype.as_numpy_dtype(str_val) < np.inf + except ValueError: + return False + + +def _infer_type(str_val, na_value, prev_type): + """Given a string, infers its tensor type. + + Infers the type of a value by picking the least 'permissive' type possible, + while still allowing the previous type inference for this column to be valid. + + Args: + str_val: String value to infer the type of. + na_value: Additional string to recognize as a NA/NaN CSV value. + prev_type: Type previously inferred based on values of this column that + we've seen up till now. + Returns: + Inferred dtype. 
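+
+  For example, if `prev_type` has been inferred as `tf.int64` and `str_val` is
+  "3.14", the inferred type widens to `tf.float32`.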
+ """ + if str_val in ("", na_value): + # If the field is null, it gives no extra information about its type + return prev_type + + type_list = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] # list of types to try, ordered from least permissive to most + + type_functions = [ + _is_valid_int32, + _is_valid_int64, + lambda str_val: _is_valid_float(str_val, dtypes.float32), + lambda str_val: _is_valid_float(str_val, dtypes.float64), + lambda str_val: True, + ] # Corresponding list of validation functions + + for i in range(len(type_list)): + validation_fn = type_functions[i] + if validation_fn(str_val) and (prev_type is None or + prev_type in type_list[:i + 1]): + return type_list[i] + + +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): + """Generator that yields rows of CSV file(s) in order.""" + for fn in filenames: + with file_io.FileIO(fn, "r") as f: + rdr = csv.reader( + f, + delimiter=field_delim, + quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE) + if header: + next(rdr) # Skip header lines + + for csv_row in rdr: + if len(csv_row) != num_cols: + raise ValueError( + "Problem inferring types: CSV row has different number of fields " + "than expected.") + yield csv_row + + +def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, + na_value, header, num_rows_for_inference, + select_columns): + """Infers column types from the first N valid CSV records of files.""" + if select_columns is None: + select_columns = range(num_cols) + inferred_types = [None] * len(select_columns) + + for i, csv_row in enumerate( + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)): + if num_rows_for_inference is not None and i >= num_rows_for_inference: + break + + for j, col_index in enumerate(select_columns): + inferred_types[j] = _infer_type(csv_row[col_index], na_value, + inferred_types[j]) + + # Replace None's with a default type + inferred_types = [t or dtypes.string for t in inferred_types] + # Default to 0 or '' for null values + return [ + constant_op.constant([0 if t is not dtypes.string else ""], dtype=t) + for t in inferred_types + ] + + +def _infer_column_names(filenames, field_delim, use_quote_delim): + """Infers column names from first rows of files.""" + csv_kwargs = { + "delimiter": field_delim, + "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE + } + with file_io.FileIO(filenames[0], "r") as f: + try: + column_names = next(csv.reader(f, **csv_kwargs)) + except StopIteration: + raise ValueError(("Received StopIteration when reading the header line " + "of %s. Empty file?") % filenames[0]) + + for name in filenames[1:]: + with file_io.FileIO(name, "r") as f: + try: + if next(csv.reader(f, **csv_kwargs)) != column_names: + raise ValueError( + "Files have different column names in the header row.") + except StopIteration: + raise ValueError(("Received StopIteration when reading the header line " + "of %s. Empty file?") % filenames[0]) + return column_names + + +def _get_sorted_col_indices(select_columns, column_names): + """Transforms select_columns argument into sorted column indices.""" + names_to_indices = {n: i for i, n in enumerate(column_names)} + num_cols = len(column_names) + for i, v in enumerate(select_columns): + if isinstance(v, int): + if v < 0 or v >= num_cols: + raise ValueError( + "Column index %d specified in select_columns out of valid range." 
% + v) + continue + if v not in names_to_indices: + raise ValueError( + "Value '%s' specified in select_columns not a valid column index or " + "name." % v) + select_columns[i] = names_to_indices[v] + + # Sort and ensure there are no duplicates + result = sorted(set(select_columns)) + if len(result) != len(select_columns): + raise ValueError("select_columns contains duplicate columns") + return result + + +def _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): + """Optionally shuffle and repeat dataset, as requested.""" + if num_epochs != 1 and shuffle: + # Use shuffle_and_repeat for perf + return dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif shuffle: + return dataset.shuffle(shuffle_buffer_size, shuffle_seed) + elif num_epochs != 1: + return dataset.repeat(num_epochs) + return dataset + + +def make_tf_record_dataset(file_pattern, + batch_size, + parser_fn=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=None, + shuffle_seed=None, + prefetch_buffer_size=optimization.AUTOTUNE, + num_parallel_reads=None, + num_parallel_parser_calls=None, + drop_final_batch=False): + """Reads and optionally parses TFRecord files into a dataset. + + Provides common functionality such as batching, optional parsing, shuffling, + and performant defaults. + + Args: + file_pattern: List of files or patterns of TFRecord file paths. + See `tf.gfile.Glob` for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + parser_fn: (Optional.) A function accepting string input to parse + and process the record contents. This function must map records + to components of a fixed shape, so they may be batched. By + default, uses the record contents unmodified. + num_epochs: (Optional.) An int specifying the number of times this + dataset is repeated. If None (the default), cycles through the + dataset forever. + shuffle: (Optional.) A bool that indicates whether the input + should be shuffled. Defaults to `True`. + shuffle_buffer_size: (Optional.) Buffer size to use for + shuffling. A large buffer size ensures better shuffling, but + increases memory usage and startup time. + shuffle_seed: (Optional.) Randomization seed to use for shuffling. + prefetch_buffer_size: (Optional.) An int specifying the number of + feature batches to prefetch for performance improvement. + Defaults to auto-tune. Set to 0 to disable prefetching. + num_parallel_reads: (Optional.) Number of threads used to read + records from files. By default or if set to a value >1, the + results will be interleaved. + num_parallel_parser_calls: (Optional.) Number of parallel + records to parse in parallel. Defaults to an automatic selection. + drop_final_batch: (Optional.) Whether the last batch should be + dropped in case its size is smaller than `batch_size`; the + default behavior is not to drop the smaller batch. + + Returns: + A dataset, where each element matches the output of `parser_fn` + except it will have an additional leading `batch-size` dimension, + or a `batch_size`-length 1-D tensor of strings if `parser_fn` is + unspecified. 
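+
+  Example (an illustrative sketch; the file pattern and feature spec below are
+  hypothetical):
+
+  ```python
+  def parser_fn(serialized):
+    return tf.parse_single_example(
+        serialized, {"label": tf.FixedLenFeature([], tf.int64)})
+
+  dataset = make_tf_record_dataset(
+      "/path/to/data-*.tfrecord", batch_size=32, parser_fn=parser_fn,
+      num_epochs=1)
+  ```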
+ """ + files = dataset_ops.Dataset.list_files( + file_pattern, shuffle=shuffle, seed=shuffle_seed) + + if num_parallel_reads is None: + # Note: We considered auto-tuning this value, but there is a concern + # that this affects the mixing of records from different files, which + # could affect training convergence/accuracy, so we are defaulting to + # a constant for now. + num_parallel_reads = 24 + dataset = core_readers.TFRecordDataset( + files, num_parallel_reads=num_parallel_reads) + + if shuffle_buffer_size is None: + # TODO(josh11b): Auto-tune this value when not specified + shuffle_buffer_size = 10000 + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + drop_final_batch = drop_final_batch or num_epochs is None + + if parser_fn is None: + dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) + else: + # TODO(josh11b): if num_parallel_parser_calls is None, use some function + # of num cores instead of map_and_batch's default behavior of one batch. + dataset = dataset.apply(batching.map_and_batch( + parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, + drop_remainder=drop_final_batch)) + + if prefetch_buffer_size == 0: + return dataset + else: + return dataset.prefetch(buffer_size=prefetch_buffer_size) + + +@tf_export("data.experimental.make_csv_dataset") +def make_csv_dataset( + file_pattern, + batch_size, + column_names=None, + column_defaults=None, + label_name=None, + select_columns=None, + field_delim=",", + use_quote_delim=True, + na_value="", + header=True, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=optimization.AUTOTUNE, + num_parallel_reads=1, + sloppy=False, + num_rows_for_inference=100, + compression_type=None, +): + """Reads CSV files into a dataset. + + Reads CSV files into a dataset, where each element is a (features, labels) + tuple that corresponds to a batch of CSV rows. The features dictionary + maps feature column names to `Tensor`s containing the corresponding + feature data, and labels is a `Tensor` containing the batch's label data. + + Args: + file_pattern: List of files or patterns of file paths containing CSV + records. See `tf.gfile.Glob` for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + column_names: An optional list of strings that corresponds to the CSV + columns, in order. One per column of the input record. If this is not + provided, infers the column names from the first row of the records. + These names will be the keys of the features dict of each dataset element. + column_defaults: A optional list of default values for the CSV fields. One + item per selected column of the input record. Each item in the list is + either a valid CSV dtype (float32, float64, int32, int64, or string), or a + `Tensor` with one of the aforementioned types. The tensor can either be + a scalar default value (if the column is optional), or an empty tensor (if + the column is required). If a dtype is provided instead of a tensor, the + column is also treated as required. 
If this list is not provided, tries + to infer types based on reading the first num_rows_for_inference rows of + files specified, and assumes all columns are optional, defaulting to `0` + for numeric values and `""` for string values. If both this and + `select_columns` are specified, these must have the same lengths, and + `column_defaults` is assumed to be sorted in order of increasing column + index. + label_name: A optional string corresponding to the label column. If + provided, the data for this column is returned as a separate `Tensor` from + the features dictionary, so that the dataset complies with the format + expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input + function. + select_columns: An optional list of integer indices or string column + names, that specifies a subset of columns of CSV data to select. If + column names are provided, these must correspond to names provided in + `column_names` or inferred from the file header lines. When this argument + is specified, only a subset of CSV columns will be parsed and returned, + corresponding to the columns specified. Using this results in faster + parsing and lower memory usage. If both this and `column_defaults` are + specified, these must have the same lengths, and `column_defaults` is + assumed to be sorted in order of increasing column index. + field_delim: An optional `string`. Defaults to `","`. Char delimiter to + separate fields in a record. + use_quote_delim: An optional bool. Defaults to `True`. If false, treats + double quotation marks as regular characters inside of the string fields. + na_value: Additional string to recognize as NA/NaN. + header: A bool that indicates whether the first rows of provided CSV files + correspond to header lines with column names, and should not be included + in the data. + num_epochs: An int specifying the number of times this dataset is repeated. + If None, cycles through the dataset forever. + shuffle: A bool that indicates whether the input should be shuffled. + shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size + ensures better shuffling, but increases memory usage and startup time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: An int specifying the number of feature + batches to prefetch for performance improvement. Recommended value is the + number of batches consumed per training step. Defaults to auto-tune. + + num_parallel_reads: Number of threads used to read CSV records from files. + If >1, the results will be interleaved. + sloppy: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + num_rows_for_inference: Number of rows of a file to use for type inference + if record_defaults is not provided. If None, reads all the rows of all + the files. Defaults to 100. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression. + + Returns: + A dataset, where each element is a (features, labels) tuple that corresponds + to a batch of `batch_size` CSV rows. 
The features dictionary maps feature + column names to `Tensor`s containing the corresponding column data, and + labels is a `Tensor` containing the column data for the label column + specified by `label_name`. + + Raises: + ValueError: If any of the arguments is malformed. + """ + # Create dataset of all matching filenames + filenames = _get_file_names(file_pattern, False) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + if shuffle: + dataset = dataset.shuffle(len(filenames), shuffle_seed) + + # Clean arguments; figure out column names and defaults + + if column_names is None: + if not header: + raise ValueError("Cannot infer column names without a header line.") + # If column names are not provided, infer from the header lines + column_names = _infer_column_names(filenames, field_delim, use_quote_delim) + if len(column_names) != len(set(column_names)): + raise ValueError("Cannot have duplicate column names.") + + if select_columns is not None: + select_columns = _get_sorted_col_indices(select_columns, column_names) + + if column_defaults is not None: + column_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in column_defaults + ] + else: + # If column defaults are not provided, infer from records at graph + # construction time + column_defaults = _infer_column_defaults( + filenames, len(column_names), field_delim, use_quote_delim, na_value, + header, num_rows_for_inference, select_columns) + + if select_columns is not None and len(column_defaults) != len(select_columns): + raise ValueError( + "If specified, column_defaults and select_columns must have same " + "length." + ) + if select_columns is not None and len(column_names) > len(select_columns): + # Pick the relevant subset of column names + column_names = [column_names[i] for i in select_columns] + + if label_name is not None and label_name not in column_names: + raise ValueError("`label_name` provided must be one of the columns.") + + def filename_to_dataset(filename): + return CsvDataset( + filename, + record_defaults=column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + select_cols=select_columns, + header=header, + compression_type=compression_type, + ) + + def map_fn(*columns): + """Organizes columns into a features dictionary. + + Args: + *columns: list of `Tensor`s corresponding to one csv record. + Returns: + An OrderedDict of feature names to values for that particular record. If + label_name is provided, extracts the label feature to be returned as the + second element of the tuple. + """ + features = collections.OrderedDict(zip(column_names, columns)) + if label_name is not None: + label = features.pop(label_name) + return features, label + return features + + # Read files sequentially (if num_parallel_reads=1) or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) + + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + # Apply batch before map for perf, because map has high overhead relative + # to the size of the computation in each map. + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. 
+ dataset = dataset.batch(batch_size=batch_size, + drop_remainder=num_epochs is None) + dataset = dataset_ops.MapDataset( + dataset, map_fn, use_inter_op_parallelism=False) + dataset = dataset.prefetch(prefetch_buffer_size) + + return dataset + + +_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB + + +@tf_export("data.experimental.CsvDataset") +class CsvDataset(dataset_ops.DatasetSource): + """A Dataset comprising lines from one or more CSV files.""" + + def __init__(self, + filenames, + record_defaults, + compression_type=None, + buffer_size=None, + header=False, + field_delim=",", + use_quote_delim=True, + na_value="", + select_cols=None): + """Creates a `CsvDataset` by reading and decoding CSV files. + + The elements of this dataset correspond to records from the file(s). + RFC 4180 format is expected for CSV files + (https://tools.ietf.org/html/rfc4180) + Note that we allow leading and trailing spaces with int or float field. + + + For example, suppose we have a file 'my_file0.csv' with four CSV columns of + different data types: + ``` + abcdefg,4.28E10,5.55E6,12 + hijklmn,-5.3E14,,2 + ``` + + We can construct a CsvDataset from it as follows: + ```python + dataset = tf.data.experimental.CsvDataset( + "my_file*.csv", + [tf.float32, # Required field, use dtype or empty tensor + tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 + tf.int32, # Required field, use dtype or empty tensor + ], + select_cols=[1,2,3] # Only parse last three columns + ) + ``` + + The expected output of its iterations is: + ```python + next_element = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(next_element)) + except tf.errors.OutOfRangeError: + break + + >> (4.28e10, 5.55e6, 12) + >> (-5.3e14, 0.0, 2) + ``` + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + record_defaults: A list of default values for the CSV fields. Each item in + the list is either a valid CSV `DType` (float32, float64, int32, int64, + string), or a `Tensor` object with one of the above types. One per + column of CSV data, with either a scalar `Tensor` default value for the + column if it is optional, or `DType` or empty `Tensor` if required. If + both this and `select_columns` are specified, these must have the same + lengths, and `column_defaults` is assumed to be sorted in order of + increasing column index. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no + compression. + buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes + to buffer while reading files. Defaults to 4MB. + header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) + have header line(s) that should be skipped when parsing. Defaults to + `False`. + field_delim: (Optional.) A `tf.string` scalar containing the delimiter + character that separates fields in a record. Defaults to `","`. + use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats + double quotation marks as regular characters inside of string fields + (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. + na_value: (Optional.) A `tf.string` scalar indicating a value that will + be treated as NA/NaN. + select_cols: (Optional.) A sorted list of column indices to select from + the input data. If specified, only this subset of columns will be + parsed. Defaults to parsing all columns. 
+ """ + super(CsvDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + self._compression_type = convert.optional_param_to_tensor( + "compression_type", + compression_type, + argument_default="", + argument_dtype=dtypes.string) + record_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in record_defaults + ] + self._record_defaults = ops.convert_n_to_tensor( + record_defaults, name="record_defaults") + self._buffer_size = convert.optional_param_to_tensor( + "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) + self._header = ops.convert_to_tensor( + header, dtype=dtypes.bool, name="header") + self._field_delim = ops.convert_to_tensor( + field_delim, dtype=dtypes.string, name="field_delim") + self._use_quote_delim = ops.convert_to_tensor( + use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") + self._na_value = ops.convert_to_tensor( + na_value, dtype=dtypes.string, name="na_value") + self._select_cols = convert.optional_param_to_tensor( + "select_cols", + select_cols, + argument_default=[], + argument_dtype=dtypes.int64, + ) + self._output_shapes = tuple( + tensor_shape.scalar() for _ in range(len(record_defaults))) + self._output_types = tuple(d.dtype for d in self._record_defaults) + self._output_classes = tuple( + ops.Tensor for _ in range(len(record_defaults))) + + def _as_variant_tensor(self): + # Constructs graph node for the dataset op. + return gen_experimental_dataset_ops.experimental_csv_dataset( + filenames=self._filenames, + record_defaults=self._record_defaults, + buffer_size=self._buffer_size, + header=self._header, + output_shapes=self._output_shapes, + field_delim=self._field_delim, + use_quote_delim=self._use_quote_delim, + na_value=self._na_value, + select_cols=self._select_cols, + compression_type=self._compression_type, + ) + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + +@tf_export("data.experimental.make_batched_features_dataset") +def make_batched_features_dataset(file_pattern, + batch_size, + features, + reader=core_readers.TFRecordDataset, + label_key=None, + reader_args=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=optimization.AUTOTUNE, + reader_num_threads=1, + parser_num_threads=2, + sloppy_ordering=False, + drop_final_batch=False): + """Returns a `Dataset` of feature dictionaries from `Example` protos. + + If label_key argument is provided, returns a `Dataset` of tuple + comprising of feature dictionaries and label. 
+ + Example: + + ``` + serialized_examples = [ + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } } + }, + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "sports" ] } } } + } + ] + ``` + + We can use arguments: + + ``` + features: { + "age": FixedLenFeature([], dtype=tf.int64, default_value=-1), + "gender": FixedLenFeature([], dtype=tf.string), + "kws": VarLenFeature(dtype=tf.string), + } + ``` + + And the expected output is: + + ```python + { + "age": [[0], [-1]], + "gender": [["f"], ["f"]], + "kws": SparseTensor( + indices=[[0, 0], [0, 1], [1, 0]], + values=["code", "art", "sports"] + dense_shape=[2, 2]), + } + ``` + + Args: + file_pattern: List of files or patterns of file paths containing + `Example` records. See `tf.gfile.Glob` for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + features: A `dict` mapping feature keys to `FixedLenFeature` or + `VarLenFeature` values. See `tf.parse_example`. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. + label_key: (Optional) A string corresponding to the key labels are stored in + `tf.Examples`. If provided, it must be one of the `features` key, + otherwise results in `ValueError`. + reader_args: Additional arguments to pass to the reader class. + num_epochs: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. Defaults to `None`. + shuffle: A boolean, indicates whether the input should be shuffled. Defaults + to `True`. + shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity + ensures better shuffling but would increase memory usage and startup time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: Number of feature batches to prefetch in order to + improve performance. Recommended value is the number of batches consumed + per training step. Defaults to auto-tune. + reader_num_threads: Number of threads used to read `Example` records. If >1, + the results will be interleaved. + parser_num_threads: Number of threads to use for parsing `Example` tensors + into a dictionary of `Feature` tensors. + sloppy_ordering: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + drop_final_batch: If `True`, and the batch size does not evenly divide the + input dataset size, the final smaller batch will be dropped. Defaults to + `False`. + + Returns: + A dataset of `dict` elements, (or a tuple of `dict` elements and label). + Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects. + + Raises: + ValueError: If `label_key` is not one of the `features` keys. 
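+
+  A minimal call, shown as an illustrative sketch (the file pattern is
+  hypothetical; `features` is the dictionary from the example above):
+
+  ```python
+  dataset = tf.data.experimental.make_batched_features_dataset(
+      file_pattern="/path/to/examples-*.tfrecord",
+      batch_size=2,
+      features=features,
+      label_key="age",
+      num_epochs=1)
+  ```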
+ """ + # Create dataset of all matching filenames + filenames = _get_file_names(file_pattern, False) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + if shuffle: + dataset = dataset.shuffle(len(filenames), shuffle_seed) + + # Read `Example` records from files as tensor objects. + if reader_args is None: + reader_args = [] + + # Read files sequentially (if reader_num_threads=1) or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + lambda filename: reader(filename, *reader_args), + cycle_length=reader_num_threads, + sloppy=sloppy_ordering)) + + # Extract values if the `Example` tensors are stored as key-value tuples. + if dataset.output_types == (dtypes.string, dtypes.string): + dataset = dataset_ops.MapDataset( + dataset, lambda _, v: v, use_inter_op_parallelism=False) + + # Apply dataset repeat and shuffle transformations. + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch( + batch_size, drop_remainder=drop_final_batch or num_epochs is None) + + # Parse `Example` tensors to a dictionary of `Feature` tensors. + dataset = dataset.apply( + parsing_ops.parse_example_dataset( + features, num_parallel_calls=parser_num_threads)) + + if label_key: + if label_key not in features: + raise ValueError( + "The `label_key` provided (%r) must be one of the `features` keys." % + label_key) + dataset = dataset.map(lambda x: (x, x.pop(label_key))) + + dataset = dataset.prefetch(prefetch_buffer_size) + return dataset + + +def _get_file_names(file_pattern, shuffle): + """Parse list of file names from pattern, optionally shuffled. + + Args: + file_pattern: File glob pattern, or list of glob patterns. + shuffle: Whether to shuffle the order of file names. + + Returns: + List of file names matching `file_pattern`. + + Raises: + ValueError: If `file_pattern` is empty, or pattern matches no files. + """ + if isinstance(file_pattern, list): + if not file_pattern: + raise ValueError("File pattern is empty.") + file_names = [] + for entry in file_pattern: + file_names.extend(gfile.Glob(entry)) + else: + file_names = list(gfile.Glob(file_pattern)) + + if not file_names: + raise ValueError("No files match %s." % file_pattern) + + # Sort files so it will be deterministic for unit tests. + if not shuffle: + file_names = sorted(file_names) + return file_names + + +@tf_export("data.experimental.SqlDataset") +class SqlDataset(dataset_ops.DatasetSource): + """A `Dataset` consisting of the results from a SQL query.""" + + def __init__(self, driver_name, data_source_name, query, output_types): + """Creates a `SqlDataset`. + + `SqlDataset` allows a user to read data from the result set of a SQL query. + For example: + + ```python + dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3", + "SELECT name, age FROM people", + (tf.string, tf.int32)) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + # Prints the rows of the result set of the above query. + while True: + try: + print(sess.run(next_element)) + except tf.errors.OutOfRangeError: + break + ``` + + Args: + driver_name: A 0-D `tf.string` tensor containing the database type. + Currently, the only supported value is 'sqlite'. 
+ data_source_name: A 0-D `tf.string` tensor containing a connection string + to connect to the database. + query: A 0-D `tf.string` tensor containing the SQL query to execute. + output_types: A tuple of `tf.DType` objects representing the types of the + columns returned by `query`. + """ + super(SqlDataset, self).__init__() + self._driver_name = ops.convert_to_tensor( + driver_name, dtype=dtypes.string, name="driver_name") + self._data_source_name = ops.convert_to_tensor( + data_source_name, dtype=dtypes.string, name="data_source_name") + self._query = ops.convert_to_tensor( + query, dtype=dtypes.string, name="query") + self._output_types = output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.sql_dataset(self._driver_name, + self._data_source_name, self._query, + nest.flatten(self.output_types), + nest.flatten(self.output_shapes)) + + @property + def output_classes(self): + return nest.map_structure(lambda _: ops.Tensor, self._output_types) + + @property + def output_shapes(self): + return nest.map_structure(lambda _: tensor_shape.TensorShape([]), + self._output_types) + + @property + def output_types(self): + return self._output_types diff --git a/tensorflow/python/data/experimental/ops/resampling.py b/tensorflow/python/data/experimental/ops/resampling.py new file mode 100644 index 0000000000..3a3040ae9a --- /dev/null +++ b/tensorflow/python/data/experimental/ops/resampling.py @@ -0,0 +1,296 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Resampling dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.experimental.ops import interleave_ops +from tensorflow.python.data.experimental.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.rejection_resample") +def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): + """A transformation that resamples a dataset to achieve a target distribution. + + **NOTE** Resampling is performed via rejection sampling; some fraction + of the input values will be dropped. + + Args: + class_func: A function mapping an element of the input dataset to a scalar + `tf.int32` tensor. Values should be in `[0, num_classes)`. + target_dist: A floating point type tensor, shaped `[num_classes]`. 
+ initial_dist: (Optional.) A floating point type tensor, shaped + `[num_classes]`. If not provided, the true class distribution is + estimated live in a streaming fashion. + seed: (Optional.) Python integer seed for the resampler. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") + class_values_ds = dataset.map(class_func) + + # Get initial distribution. + if initial_dist is not None: + initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") + acceptance_dist, prob_of_original = ( + _calculate_acceptance_probs_with_mixing(initial_dist_t, + target_dist_t)) + initial_dist_ds = dataset_ops.Dataset.from_tensors( + initial_dist_t).repeat() + acceptance_dist_ds = dataset_ops.Dataset.from_tensors( + acceptance_dist).repeat() + prob_of_original_ds = dataset_ops.Dataset.from_tensors( + prob_of_original).repeat() + else: + initial_dist_ds = _estimate_initial_dist_ds( + target_dist_t, class_values_ds) + acceptance_and_original_prob_ds = initial_dist_ds.map( + lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda + initial, target_dist_t)) + acceptance_dist_ds = acceptance_and_original_prob_ds.map( + lambda accept_prob, _: accept_prob) + prob_of_original_ds = acceptance_and_original_prob_ds.map( + lambda _, prob_original: prob_original) + filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, + class_values_ds, seed) + # Prefetch filtered dataset for speed. + filtered_ds = filtered_ds.prefetch(3) + + prob_original_static = _get_prob_original_static( + initial_dist_t, target_dist_t) if initial_dist is not None else None + if prob_original_static == 1: + return dataset_ops.Dataset.zip((class_values_ds, dataset)) + elif prob_original_static == 0: + return filtered_ds + else: + return interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds], + weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), + seed=seed) + + return _apply_fn + + +def _get_prob_original_static(initial_dist_t, target_dist_t): + """Returns the static probability of sampling from the original. + + `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters + an Op that it isn't defined for. We have some custom logic to avoid this. + + Args: + initial_dist_t: A tensor of the initial distribution. + target_dist_t: A tensor of the target distribution. + + Returns: + The probability of sampling from the original distribution as a constant, + if it is a constant, or `None`. + """ + init_static = tensor_util.constant_value(initial_dist_t) + target_static = tensor_util.constant_value(target_dist_t) + + if init_static is None or target_static is None: + return None + else: + return np.min(target_static / init_static) + + +def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, + seed): + """Filters a dataset based on per-class acceptance probabilities. + + Args: + dataset: The dataset to be filtered. + acceptance_dist_ds: A dataset of acceptance probabilities. + initial_dist_ds: A dataset of the initial probability distribution, given or + estimated. + class_values_ds: A dataset of the corresponding classes. + seed: (Optional.) Python integer seed for the resampler. + + Returns: + A dataset of (class value, data) after filtering. 
+ """ + def maybe_warn_on_large_rejection(accept_dist, initial_dist): + proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) + return control_flow_ops.cond( + math_ops.less(proportion_rejected, .5), + lambda: accept_dist, + lambda: logging_ops.Print( # pylint: disable=g-long-lambda + accept_dist, [proportion_rejected, initial_dist, accept_dist], + message="Proportion of examples rejected by sampler is high: ", + summarize=100, + first_n=10)) + + acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, + initial_dist_ds)) + .map(maybe_warn_on_large_rejection)) + + def _gather_and_copy(class_val, acceptance_prob, data): + return class_val, array_ops.gather(acceptance_prob, class_val), data + + current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( + (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) + filtered_ds = ( + current_probabilities_and_class_and_data_ds + .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) + return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + + +def _estimate_initial_dist_ds( + target_dist_t, class_values_ds, dist_estimation_batch_size=32, + smoothing_constant=10): + num_classes = (target_dist_t.shape[0].value or + array_ops.shape(target_dist_t)[0]) + initial_examples_per_class_seen = array_ops.fill( + [num_classes], np.int64(smoothing_constant)) + + def update_estimate_and_tile(num_examples_per_class_seen, c): + updated_examples_per_class_seen, dist = _estimate_data_distribution( + c, num_examples_per_class_seen) + tiled_dist = array_ops.tile( + array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) + return updated_examples_per_class_seen, tiled_dist + + initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) + .apply(scan_ops.scan(initial_examples_per_class_seen, + update_estimate_and_tile)) + .apply(batching.unbatch())) + + return initial_dist_ds + + +def _get_target_to_initial_ratio(initial_probs, target_probs): + # Add tiny to initial_probs to avoid divide by zero. + denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) + return target_probs / denom + + +def _estimate_data_distribution(c, num_examples_per_class_seen): + """Estimate data distribution as labels are seen. + + Args: + c: The class labels. Type `int32`, shape `[batch_size]`. + num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, + containing counts. + + Returns: + num_examples_per_lass_seen: Updated counts. Type `int64`, shape + `[num_classes]`. + dist: The updated distribution. Type `float32`, shape `[num_classes]`. + """ + num_classes = num_examples_per_class_seen.get_shape()[0].value + # Update the class-count based on what labels are seen in batch. + num_examples_per_class_seen = math_ops.add( + num_examples_per_class_seen, math_ops.reduce_sum( + array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) + init_prob_estimate = math_ops.truediv( + num_examples_per_class_seen, + math_ops.reduce_sum(num_examples_per_class_seen)) + dist = math_ops.cast(init_prob_estimate, dtypes.float32) + return num_examples_per_class_seen, dist + + +def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): + """Calculates the acceptance probabilities and mixing ratio. + + In this case, we assume that we can *either* sample from the original data + distribution with probability `m`, or sample from a reshaped distribution + that comes from rejection sampling on the original distribution. 
This + rejection sampling is done on a per-class basis, with `a_i` representing the + probability of accepting data from class `i`. + + This method is based on solving the following analysis for the reshaped + distribution: + + Let F be the probability of a rejection (on any example). + Let p_i be the proportion of examples in the data in class i (init_probs) + Let a_i is the rate the rejection sampler should *accept* class i + Let t_i is the target proportion in the minibatches for class i (target_probs) + + ``` + F = sum_i(p_i * (1-a_i)) + = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 + ``` + + An example with class `i` will be accepted if `k` rejections occur, then an + example with class `i` is seen by the rejector, and it is accepted. This can + be written as follows: + + ``` + t_i = sum_k=0^inf(F^k * p_i * a_i) + = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 + = p_i * a_i / sum_j(p_j * a_j) using F from above + ``` + + Note that the following constraints hold: + ``` + 0 <= p_i <= 1, sum_i(p_i) = 1 + 0 <= a_i <= 1 + 0 <= t_i <= 1, sum_i(t_i) = 1 + ``` + + A solution for a_i in terms of the other variables is the following: + ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` + + If we try to minimize the amount of data rejected, we get the following: + + M_max = max_i [ t_i / p_i ] + M_min = min_i [ t_i / p_i ] + + The desired probability of accepting data if it comes from class `i`: + + a_i = (t_i/p_i - m) / (M_max - m) + + The desired probability of pulling a data element from the original dataset, + rather than the filtered one: + + m = M_min + + Args: + initial_probs: A Tensor of the initial probability distribution, given or + estimated. + target_probs: A Tensor of the corresponding classes. + + Returns: + (A 1D Tensor with the per-class acceptance probabilities, the desired + probability of pull from the original distribution.) + """ + ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) + max_ratio = math_ops.reduce_max(ratio_l) + min_ratio = math_ops.reduce_min(ratio_l) + + # Target prob to sample from original distribution. + m = min_ratio + + # TODO(joelshor): Simplify fraction, if possible. + a_i = (ratio_l - m) / (max_ratio - m) + return a_i, m diff --git a/tensorflow/python/data/experimental/ops/scan_ops.py b/tensorflow/python/data/experimental/ops/scan_ops.py new file mode 100644 index 0000000000..e05e7c5a18 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/scan_ops.py @@ -0,0 +1,177 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Scan dataset transformation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export + + +class _ScanDataset(dataset_ops.UnaryDataset): + """A dataset that scans a function across its input.""" + + def __init__(self, input_dataset, initial_state, scan_func): + """See `scan()` for details.""" + super(_ScanDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + + with ops.name_scope("initial_state"): + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + self._initial_state = nest.pack_sequence_as(initial_state, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor( + t, name="component_%d" % i) + for i, t in enumerate(nest.flatten(initial_state)) + ]) + + # Compute initial values for the state classes, shapes and types based on + # the initial state. The shapes may be refined by running `tf_scan_func` one + # or more times below. + self._state_classes = sparse.get_classes(self._initial_state) + self._state_shapes = nest.pack_sequence_as( + self._initial_state, + [t.get_shape() for t in nest.flatten(self._initial_state)]) + self._state_types = nest.pack_sequence_as( + self._initial_state, + [t.dtype for t in nest.flatten(self._initial_state)]) + + # Will be populated by calling `tf_scan_func`. + self._output_classes = None + self._output_shapes = None + self._output_types = None + + # Iteratively rerun the scan function until reaching a fixed point on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + wrapped_func = dataset_ops.StructuredFunctionWrapper( + scan_func, + "tf.data.experimental.scan()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + if not ( + isinstance(wrapped_func.output_types, collections.Sequence) and + len(wrapped_func.output_types) == 2): + raise TypeError("The scan function must return a pair comprising the " + "new state and the output value.") + + new_state_classes, self._output_classes = wrapped_func.output_classes + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(new_state_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, new_state_classes)) + + # Extract and validate type information from the returned values. + new_state_types, self._output_types = wrapped_func.output_types + for new_state_type, state_type in zip( + nest.flatten(new_state_types), nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." 
% + (self._state_types, new_state_types)) + + # Extract shape information from the returned values. + new_state_shapes, self._output_shapes = wrapped_func.output_shapes + + flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(new_state_shapes) + weakened_state_shapes = [ + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( + weakened_shape.ndims is None or + original_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._scan_func = wrapped_func.function + self._scan_func.add_to_graph(ops.get_default_graph()) + + def _as_variant_tensor(self): + input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + return gen_dataset_ops.scan_dataset( + input_t, + nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), + self._scan_func.captured_inputs, + f=self._scan_func, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +@tf_export("data.experimental.scan") +def scan(initial_state, scan_func): + """A transformation that scans a function across an input dataset. + + This transformation is a stateful relative of `tf.data.Dataset.map`. + In addition to mapping `scan_func` across the elements of the input dataset, + `scan()` accumulates one or more state tensors, whose initial values are + `initial_state`. + + Args: + initial_state: A nested structure of tensors, representing the initial state + of the accumulator. + scan_func: A function that maps `(old_state, input_element)` to + `(new_state, output_element). It must take two arguments and return a + pair of nested structures of tensors. The `new_state` must match the + structure of `initial_state`. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + def _apply_fn(dataset): + return _ScanDataset(dataset, initial_state, scan_func) + + return _apply_fn diff --git a/tensorflow/python/data/experimental/ops/shuffle_ops.py b/tensorflow/python/data/experimental/ops/shuffle_ops.py new file mode 100644 index 0000000000..a4307212da --- /dev/null +++ b/tensorflow/python/data/experimental/ops/shuffle_ops.py @@ -0,0 +1,102 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Experimental shuffle ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import random_seed +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export + + +class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset): + """A `Dataset` that fuses `shuffle` and `repeat`.""" + + def __init__(self, input_dataset, buffer_size, count=None, seed=None): + super(_ShuffleAndRepeatDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._buffer_size = ops.convert_to_tensor( + buffer_size, dtype=dtypes.int64, name="buffer_size") + if count is None: + self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") + else: + self._count = ops.convert_to_tensor( + count, dtype=dtypes.int64, name="count") + self._seed, self._seed2 = random_seed.get_seed(seed) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + input_resource = self._input_dataset._as_variant_tensor() + return gen_dataset_ops.shuffle_and_repeat_dataset( + input_resource, + buffer_size=self._buffer_size, + count=self._count, + seed=self._seed, + seed2=self._seed2, + **dataset_ops.flat_structure(self)) + # pylint: enable=protected-access + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +@tf_export("data.experimental.shuffle_and_repeat") +def shuffle_and_repeat(buffer_size, count=None, seed=None): + """Shuffles and repeats a Dataset returning a new permutation for each epoch. + + `dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size, count))` + + is equivalent to + + `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)` + + The difference is that the latter dataset is not serializable. So, + if you need to checkpoint an input pipeline with reshuffling you must use + this implementation. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + maximum number elements that will be buffered when prefetching. + count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + number of times the dataset should be repeated. The default behavior + (if `count` is `None` or `-1`) is for the dataset be repeated + indefinitely. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + `tf.set_random_seed` for behavior. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): # pylint: disable=missing-docstring + return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed) + + return _apply_fn diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py new file mode 100644 index 0000000000..c918d223e8 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/stats_ops.py @@ -0,0 +1,205 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for gathering statistics from `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.StatsAggregator") +class StatsAggregator(object): + """A stateful resource that aggregates statistics from one or more iterators. + + To record statistics, use one of the custom transformation functions defined + in this module when defining your `tf.data.Dataset`. All statistics will be + aggregated by the `StatsAggregator` that is associated with a particular + iterator (see below). For example, to record the latency of producing each + element by iterating over a dataset: + + ```python + dataset = ... + dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes")) + ``` + + To associate a `StatsAggregator` with a `tf.data.Dataset` object, use + the following pattern: + + ```python + stats_aggregator = stats_ops.StatsAggregator() + dataset = ... + + # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`. + dataset = dataset.apply( + tf.data.experimental.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_one_shot_iterator() + ``` + + To get a protocol buffer summary of the currently aggregated statistics, + use the `StatsAggregator.get_summary()` tensor. The easiest way to do this + is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection, + so that the summaries will be included with any existing summaries. + + ```python + stats_aggregator = stats_ops.StatsAggregator() + # ... + stats_summary = stats_aggregator.get_summary() + tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary) + ``` + + Note: This interface is experimental and expected to change. In particular, + we expect to add other implementations of `StatsAggregator` that provide + different ways of exporting statistics, and add more types of statistics. + """ + + def __init__(self): + """Creates a `StatsAggregator`.""" + self._resource = gen_dataset_ops.stats_aggregator_handle() + + # TODO(b/116314787): Update this/add support for V2 summary API. + def get_summary(self): + """Returns a string `tf.Tensor` that summarizes the aggregated statistics. + + The returned tensor will contain a serialized `tf.summary.Summary` protocol + buffer, which can be used with the standard TensorBoard logging facilities. + + Returns: + A scalar string `tf.Tensor` that summarizes the aggregated statistics. 
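+
+    For example, in graph mode the serialized summary can be written with a
+    `tf.summary.FileWriter` (an illustrative sketch; `sess`, `summary_writer`,
+    and `step` are assumed to exist):
+
+    ```python
+    summary_str = sess.run(stats_aggregator.get_summary())
+    summary_writer.add_summary(summary_str, global_step=step)
+    ```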
+ """ + return gen_dataset_ops.stats_aggregator_summary(self._resource) + + +class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset): + """A `Dataset` that acts as an identity, and sets given stats_aggregator.""" + + def __init__(self, input_dataset, stats_aggregator): + super(_SetStatsAggregatorDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._stats_aggregator = stats_aggregator + + def _as_variant_tensor(self): + return gen_dataset_ops.set_stats_aggregator_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._stats_aggregator._resource, # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +@tf_export("data.experimental.set_stats_aggregator") +def set_stats_aggregator(stats_aggregator): + """Set the given `stats_aggregator` for aggregating the input dataset stats. + + Args: + stats_aggregator: A `tf.data.experimental.StatsAggregator` object. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + return _SetStatsAggregatorDataset(dataset, stats_aggregator) + + return _apply_fn + + +# TODO(b/38416882): Properly export in the `tf.data.experimental` API when +# stable or make private / remove. +def bytes_produced_stats(tag): + """Records the number of bytes produced by each element of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with the output + dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will + be associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset, + tag) + + return _apply_fn + + +@tf_export("data.experimental.latency_stats") +def latency_stats(tag): + """Records the latency of producing each element of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with the output + dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will + be associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. 
+ """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag) + + return _apply_fn + + +class _StatsDataset(dataset_ops.UnaryDataset): + """A `Dataset` that acts as an identity, and also records statistics.""" + + def __init__(self, input_dataset, op_function, tag): + super(_StatsDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._op_function = op_function + self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string) + + def _as_variant_tensor(self): + return self._op_function( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._tag, + **dataset_ops.flat_structure(self)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes diff --git a/tensorflow/python/data/experimental/ops/threadpool.py b/tensorflow/python/data/experimental/ops/threadpool.py new file mode 100644 index 0000000000..3ea017c6e8 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/threadpool.py @@ -0,0 +1,104 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for controlling threading in `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops +from tensorflow.python.ops import resource_variable_ops + +_uid_counter = 0 +_uid_lock = threading.Lock() + + +def _generate_shared_name(prefix): + with _uid_lock: + global _uid_counter + uid = _uid_counter + _uid_counter += 1 + return "{}{}".format(prefix, uid) + + +# TODO(b/73383364): Properly export in the `tf.data.experimental` API when +# stable or make private / remove. 
+class PrivateThreadPool(object): + """A stateful resource that represents a private thread pool.""" + + def __init__(self, num_threads, display_name=None, + max_intra_op_parallelism=1): + """Creates a `PrivateThreadPool` with the given number of threads.""" + if context.executing_eagerly(): + shared_name = _generate_shared_name("privatethreadpool") + self._resource = ged_ops.experimental_thread_pool_handle( + num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name=display_name, + shared_name=shared_name) + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device=context.context().device_name) + else: + self._resource = ged_ops.experimental_thread_pool_handle( + num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name=display_name) + + +class _ThreadPoolDataset(dataset_ops.UnaryDataset): + """A `Dataset` that acts as an identity, and sets a custom threadpool.""" + + def __init__(self, input_dataset, thread_pool): + super(_ThreadPoolDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._thread_pool = thread_pool + + def _as_variant_tensor(self): + return ged_ops.experimental_thread_pool_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._thread_pool._resource, # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +# TODO(b/73383364): Properly export in the `tf.data.experimental` API when +# stable or make private / remove. +def override_threadpool(dataset, thread_pool): + """Returns a new dataset that uses the given thread pool for its operations. + + Args: + dataset: A `tf.data.Dataset` object. + thread_pool: A `PrivateThreadPool` object. + + Returns: + A dataset containing the same values as `dataset`, but which uses + `thread_pool` to compute any of its parallel operations (such as + `tf.data.Dataset.map`). + """ + return _ThreadPoolDataset(dataset, thread_pool) diff --git a/tensorflow/python/data/experimental/ops/unique.py b/tensorflow/python/data/experimental/ops/unique.py new file mode 100644 index 0000000000..2a7775c456 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/unique.py @@ -0,0 +1,79 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
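`PrivateThreadPool` and `override_threadpool` above are not yet exported under `tf.data.experimental` (see the TODO), so a sketch has to import the internal module directly; the thread count and display name are placeholders, and the module path assumes the layout introduced by this change.

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import threadpool

dataset = tf.data.Dataset.range(100).map(
    lambda x: x * 2, num_parallel_calls=4)

# Run the dataset's parallel work on a dedicated 8-thread pool rather than
# the default inter-op pool; the size and display name are illustrative.
pool = threadpool.PrivateThreadPool(8, display_name="input_pipeline_pool")
dataset = threadpool.override_threadpool(dataset, pool)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  print(sess.run(next_element))  # ==> 0
```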
+# ==============================================================================
+"""Unique element dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.unique")
+def unique():
+  """Creates a `Dataset` from another `Dataset`, discarding duplicates.
+
+  Use this transformation to produce a dataset that contains one instance of
+  each unique element in the input. For example:
+
+  ```python
+  dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
+
+  # Using `unique()` will drop the duplicate elements.
+  dataset = dataset.apply(tf.data.experimental.unique())  # ==> { 1, 37, 2 }
+  ```
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    `tf.data.Dataset.apply`.
+  """
+
+  def _apply_fn(dataset):
+    return _UniqueDataset(dataset)
+
+  return _apply_fn
+
+
+class _UniqueDataset(dataset_ops.UnaryDataset):
+  """A `Dataset` containing the unique elements of its input."""
+
+  def __init__(self, input_dataset):
+    """See `unique()` for details."""
+    super(_UniqueDataset, self).__init__(input_dataset)
+    self._input_dataset = input_dataset
+    if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
+                                          dtypes.string):
+      raise TypeError(
+          "`tf.data.experimental.unique()` only supports inputs with a single "
+          "`tf.int32`, `tf.int64`, or `tf.string` component.")
+
+  def _as_variant_tensor(self):
+    return gen_experimental_dataset_ops.experimental_unique_dataset(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        **dataset_ops.flat_structure(self))
+
+  @property
+  def output_classes(self):
+    return self._input_dataset.output_classes
+
+  @property
+  def output_shapes(self):
+    return self._input_dataset.output_shapes
+
+  @property
+  def output_types(self):
+    return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py new file mode 100644 index 0000000000..994447cb4d --- /dev/null +++ b/tensorflow/python/data/experimental/ops/writers.py @@ -0,0 +1,60 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
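A short sketch of the single-component type restriction that `_UniqueDataset.__init__` above enforces, with illustrative element values; only `tf.int32`, `tf.int64`, and `tf.string` components are accepted.

```python
import tensorflow as tf

# Works: a dataset with a single tf.string component.
names = tf.data.Dataset.from_tensor_slices(["a", "b", "a", "c", "b"])
unique_names = names.apply(tf.data.experimental.unique())  # ==> "a", "b", "c"

# Raises TypeError: float32 is not one of the supported component types.
floats = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 1.0])
try:
  floats.apply(tf.data.experimental.unique())
except TypeError as err:
  print(err)
```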
+# ============================================================================== +"""Python wrappers for tf.data writers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("data.experimental.TFRecordWriter") +class TFRecordWriter(object): + """Writes data to a TFRecord file.""" + + def __init__(self, filename, compression_type=None): + self._filename = ops.convert_to_tensor( + filename, dtypes.string, name="filename") + self._compression_type = convert.optional_param_to_tensor( + "compression_type", + compression_type, + argument_default="", + argument_dtype=dtypes.string) + + def write(self, dataset): + """Returns a `tf.Operation` to write a dataset to a file. + + Args: + dataset: a `tf.data.Dataset` whose elements are to be written to a file + + Returns: + A `tf.Operation` that, when run, writes contents of `dataset` to a file. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + if (dataset.output_types != dtypes.string or + dataset.output_shapes != tensor_shape.scalar()): + raise TypeError( + "`dataset` must produce scalar `DT_STRING` tensors whereas it " + "produces shape {0} and types {1}".format(dataset.output_shapes, + dataset.output_types)) + return gen_dataset_ops.dataset_to_tf_record( + dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6bba72a8e9..3b9d3a639d 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -889,8 +889,8 @@ class Dataset(object): will be padded out to the maximum length of all elements in that dimension. - See also `tf.contrib.data.dense_to_sparse_batch`, which combines elements - that may have different shapes into a `tf.SparseTensor`. + See also `tf.data.experimental.dense_to_sparse_batch`, which combines + elements that may have different shapes into a `tf.SparseTensor`. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py index 3bbebd7878..aca989e03a 100644 --- a/tensorflow/python/data/ops/optional_ops.py +++ b/tensorflow/python/data/ops/optional_ops.py @@ -31,7 +31,7 @@ class Optional(object): An `Optional` can represent the result of an operation that may fail as a value, rather than raising an exception and halting execution. For example, - `tf.contrib.data.get_next_as_optional` returns an `Optional` that either + `tf.data.experimental.get_next_as_optional` returns an `Optional` that either contains the next value from a `tf.data.Iterator` if one exists, or a "none" value that indicates the end of the sequence has been reached. """ @@ -111,7 +111,7 @@ class Optional(object): class _OptionalImpl(Optional): - """Concrete implementation of `tf.contrib.data.Optional`. + """Concrete implementation of `tf.data.experimental.Optional`. 
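Returning to the `writers.py` listing above: a hedged end-to-end sketch of `TFRecordWriter`, assuming TensorFlow 1.12+ in graph mode; the output path and record contents are placeholders.

```python
import tensorflow as tf

# TFRecordWriter.write() requires a dataset of scalar tf.string elements.
records = tf.data.Dataset.from_tensor_slices([b"record one", b"record two"])

writer = tf.data.experimental.TFRecordWriter("/tmp/demo.tfrecord")  # placeholder path
write_op = writer.write(records)

with tf.Session() as sess:
  sess.run(write_op)

# The file can then be read back with the core reader dataset.
read_back = tf.data.TFRecordDataset("/tmp/demo.tfrecord")
```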
NOTE(mrry): This implementation is kept private, to avoid defining `Optional.__init__()` in the public API. diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index b0f26631f9..d08da6704c 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -129,7 +129,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset): def __init__(self, input_dataset, map_func, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements): - """See `tf.contrib.data.parallel_interleave()` for details.""" + """See `tf.data.experimental.parallel_interleave()` for details.""" super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func, cycle_length, block_length) self._sloppy = ops.convert_to_tensor( @@ -158,7 +158,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset): # pylint: enable=protected-access def _transformation_name(self): - return "tf.contrib.data.parallel_interleave()" + return "tf.data.experimental.parallel_interleave()" @tf_export("data.TFRecordDataset") |
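Finally, since the `optional_ops.py` hunk above now points at `tf.data.experimental.get_next_as_optional`, here is a sketch of the pattern it enables, with illustrative dataset contents; `tf.cond` guards `get_value()` so it only runs when a value is present.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()

# The Optional either holds the next element or a "none" end-of-sequence value.
next_opt = tf.data.experimental.get_next_as_optional(iterator)
result = tf.cond(
    next_opt.has_value(),
    lambda: next_opt.get_value() + 1,         # only runs when a value exists
    lambda: tf.constant(-1, dtype=tf.int64))  # sentinel after exhaustion

with tf.Session() as sess:
  sess.run(iterator.initializer)
  for _ in range(7):
    print(sess.run(result))  # 1, 2, 3, 4, 5, -1, -1
```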