| author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-09-13 16:42:57 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-13 16:47:05 -0700 |
| commit | 3b438e4a24dd0f113f1d36d97196a027bd473fc4 (patch) | |
| tree | d6191ef394a98a2348ee109faef7604867759b4b /tensorflow/contrib/data | |
| parent | 5dd20118a25e8d29b7684cf5fb17951657a4a687 (diff) | |
[tf.data] Changes `make_batched_features_dataset` and `make_tf_record_dataset` default `prefetch` buffer size to auto-tune (from 1).
PiperOrigin-RevId: 212900920
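
Net effect: pipelines built with these contrib readers now resolve their `prefetch` buffer size to the auto-tune sentinel unless the caller passes an explicit value. A minimal usage sketch under the TF 1.x contrib API touched by this commit; the file pattern is hypothetical, not part of the change:

```python
import tensorflow as tf

# prefetch_buffer_size now defaults to auto-tune, so nothing needs to be
# passed; "train.tfrecord-*" is a hypothetical file pattern for illustration.
dataset = tf.contrib.data.make_tf_record_dataset(
    file_pattern="train.tfrecord-*",
    batch_size=32)
```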
Diffstat (limited to 'tensorflow/contrib/data')
| -rw-r--r-- | tensorflow/contrib/data/__init__.py | 9 |
| -rw-r--r-- | tensorflow/contrib/data/python/ops/BUILD | 1 |
| -rw-r--r-- | tensorflow/contrib/data/python/ops/optimization.py | 3 |
| -rw-r--r-- | tensorflow/contrib/data/python/ops/readers.py | 39 |
4 files changed, 29 insertions, 23 deletions
```diff
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index baec238c62..c378b1ce8d 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -62,6 +62,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
 @@sloppy_interleave
 @@unbatch
 @@unique
+
+@@AUTOTUNE
 """
 
 from __future__ import absolute_import
@@ -91,6 +93,10 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
 from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
 from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
 from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
+
 from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
 from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
 from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -113,6 +119,3 @@ from tensorflow.python.data.ops.optional_ops import Optional
 from tensorflow.python.util.all_util import remove_undocumented
 
 remove_undocumented(__name__)
-
-# A constant that can be used to enable auto-tuning.
-AUTOTUNE = -1
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 4b45cc7e36..a14781cd93 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,6 +80,7 @@ py_library(
         ":batching",
         ":gen_dataset_ops",
         ":interleave_ops",
+        ":optimization",
         ":parsing_ops",
         ":shuffle_ops",
         "//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 4114b62e29..73840452df 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -24,6 +24,9 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
 
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
+
 # TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
 # account for indexing) and transformation sequence.
```
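
Since `AUTOTUNE` is now defined in `optimization.py` and re-exported as `tf.contrib.data.AUTOTUNE`, the same `-1` sentinel can also be passed explicitly wherever a prefetch buffer size is accepted. A hedged sketch, not taken from the diff itself:

```python
import tensorflow as tf

# Passing the exported AUTOTUNE sentinel (-1) asks the tf.data runtime to
# pick the prefetch buffer size dynamically instead of using a fixed value.
dataset = tf.data.Dataset.range(1000).batch(32)
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
```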
```diff
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 4c466781f7..785b395707 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,6 +25,7 @@ import numpy as np
 from tensorflow.contrib.data.python.ops import batching
 from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
 from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import optimization
 from tensorflow.contrib.data.python.ops import parsing_ops
 from tensorflow.contrib.data.python.ops import shuffle_ops
 from tensorflow.python.data.ops import dataset_ops
@@ -214,18 +215,17 @@ def _maybe_shuffle_and_repeat(
   return dataset
 
 
-def make_tf_record_dataset(
-    file_pattern,
-    batch_size,
-    parser_fn=None,
-    num_epochs=None,
-    shuffle=True,
-    shuffle_buffer_size=None,
-    shuffle_seed=None,
-    prefetch_buffer_size=None,
-    num_parallel_reads=None,
-    num_parallel_parser_calls=None,
-    drop_final_batch=False):
+def make_tf_record_dataset(file_pattern,
+                           batch_size,
+                           parser_fn=None,
+                           num_epochs=None,
+                           shuffle=True,
+                           shuffle_buffer_size=None,
+                           shuffle_seed=None,
+                           prefetch_buffer_size=optimization.AUTOTUNE,
+                           num_parallel_reads=None,
+                           num_parallel_parser_calls=None,
+                           drop_final_batch=False):
   """Reads and optionally parses TFRecord files into a dataset.
 
   Provides common functionality such as batching, optional parsing, shuffling,
@@ -300,8 +300,6 @@ def make_tf_record_dataset(
             parser_fn, batch_size,
             num_parallel_calls=num_parallel_parser_calls,
             drop_remainder=drop_final_batch))
-  if prefetch_buffer_size is None:
-    prefetch_buffer_size = -1  # tf.config.data.AUTOTUNE
   if prefetch_buffer_size == 0:
     return dataset
   else:
@@ -323,7 +321,7 @@ def make_csv_dataset(
     shuffle=True,
     shuffle_buffer_size=10000,
     shuffle_seed=None,
-    prefetch_buffer_size=1,
+    prefetch_buffer_size=optimization.AUTOTUNE,
     num_parallel_reads=1,
     sloppy=False,
     num_rows_for_inference=100,
@@ -386,9 +384,10 @@ def make_csv_dataset(
     shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
       ensures better shuffling, but increases memory usage and startup time.
     shuffle_seed: Randomization seed to use for shuffling.
-    prefetch_buffer_size: An int specifying the number of feature batches to
-      prefetch for performance improvement. Recommended value is the number of
-      batches consumed per training step.
+    prefetch_buffer_size: An int specifying the number of feature
+      batches to prefetch for performance improvement. Recommended value is the
+      number of batches consumed per training step. Defaults to auto-tune.
+
     num_parallel_reads: Number of threads used to read CSV records from files.
       If >1, the results will be interleaved.
     sloppy: If `True`, reading performance will be improved at
@@ -666,7 +665,7 @@ def make_batched_features_dataset(file_pattern,
                                   shuffle=True,
                                   shuffle_buffer_size=10000,
                                   shuffle_seed=None,
-                                  prefetch_buffer_size=1,
+                                  prefetch_buffer_size=optimization.AUTOTUNE,
                                   reader_num_threads=1,
                                   parser_num_threads=2,
                                   sloppy_ordering=False,
@@ -739,7 +738,7 @@ def make_batched_features_dataset(file_pattern,
     shuffle_seed: Randomization seed to use for shuffling.
     prefetch_buffer_size: Number of feature batches to prefetch in order to
       improve performance. Recommended value is the number of batches consumed
-      per training step (default is 1).
+      per training step. Defaults to auto-tune.
     reader_num_threads: Number of threads used to read `Example` records.
       If >1, the results will be interleaved.
     parser_num_threads: Number of threads to use for parsing `Example` tensors
```
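
With the new defaults, existing callers of `make_batched_features_dataset` pick up auto-tuned prefetching with no code change, while an explicit integer restores the old fixed-size behavior. A sketch under the same contrib API; the feature spec and file pattern are hypothetical:

```python
import tensorflow as tf

# Hypothetical feature spec and file pattern, for illustration only.
features = {
    "label": tf.FixedLenFeature([], tf.int64),
    "text": tf.VarLenFeature(tf.string),
}

# Default: prefetch buffer size is auto-tuned (previously fixed at 1).
dataset = tf.contrib.data.make_batched_features_dataset(
    file_pattern="examples.tfrecord-*",
    batch_size=64,
    features=features)

# Opting out: an explicit size disables auto-tuning.
dataset_fixed = tf.contrib.data.make_batched_features_dataset(
    file_pattern="examples.tfrecord-*",
    batch_size=64,
    features=features,
    prefetch_buffer_size=1)
```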