author    Shivani Agrawal <shivaniagrawal@google.com>  2018-09-13 16:42:57 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-13 16:47:05 -0700
commit    3b438e4a24dd0f113f1d36d97196a027bd473fc4 (patch)
tree      d6191ef394a98a2348ee109faef7604867759b4b /tensorflow/contrib/data
parent    5dd20118a25e8d29b7684cf5fb17951657a4a687 (diff)
[tf.data] Changes the default `prefetch_buffer_size` of `make_batched_features_dataset` and `make_tf_record_dataset` to auto-tune (previously 1).
PiperOrigin-RevId: 212900920
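
The practical effect, sketched below under the TF 1.x contrib API at this
commit (the file pattern is a hypothetical placeholder):

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import readers

    # Omitting prefetch_buffer_size used to prefetch exactly 1 batch; after
    # this change the default is optimization.AUTOTUNE (-1), letting the
    # tf.data runtime pick the buffer size dynamically.
    dataset = readers.make_tf_record_dataset(
        file_pattern="/tmp/data/*.tfrecord",  # hypothetical path
        batch_size=32)

    # An explicit value still overrides the auto-tuned default:
    dataset_fixed = readers.make_tf_record_dataset(
        file_pattern="/tmp/data/*.tfrecord",  # hypothetical path
        batch_size=32,
        prefetch_buffer_size=4)  # prefetch exactly 4 batches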
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--  tensorflow/contrib/data/__init__.py                   9
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD               1
-rw-r--r--  tensorflow/contrib/data/python/ops/optimization.py     3
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py         39
4 files changed, 29 insertions(+), 23 deletions(-)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index baec238c62..c378b1ce8d 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -62,6 +62,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@sloppy_interleave
@@unbatch
@@unique
+
+@@AUTOTUNE
"""
from __future__ import absolute_import
@@ -91,6 +93,10 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
+
from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -113,6 +119,3 @@ from tensorflow.python.data.ops.optional_ops import Optional
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
-
-# A constant that can be used to enable auto-tuning.
-AUTOTUNE = -1
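
With the re-export above, callers can reference the sentinel through the
public contrib namespace instead of hard-coding -1; a minimal sketch,
assuming TF 1.x:

    import tensorflow as tf

    dataset = tf.data.Dataset.range(1000).batch(32)
    # AUTOTUNE is the -1 sentinel defined in optimization.py; passing it to
    # prefetch() defers the buffer-size choice to the tf.data runtime, the
    # same way the readers in this change do internally.
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)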
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 4b45cc7e36..a14781cd93 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,6 +80,7 @@ py_library(
":batching",
":gen_dataset_ops",
":interleave_ops",
+ ":optimization",
":parsing_ops",
":shuffle_ops",
"//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 4114b62e29..73840452df 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -24,6 +24,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
+
# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
# account for indexing) and transformation sequence.
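
The constant itself is only a sentinel; the readers decide what it means
when applying the final prefetch. A sketch of that dispatch, mirroring the
`make_tf_record_dataset` tail shown below (the helper name is illustrative,
not part of the diff):

    from tensorflow.contrib.data.python.ops import optimization

    def _apply_prefetch(dataset, prefetch_buffer_size=optimization.AUTOTUNE):
      # 0 disables prefetching entirely, a positive int fixes the buffer
      # size, and AUTOTUNE (-1) lets the runtime tune it.
      if prefetch_buffer_size == 0:
        return dataset
      return dataset.prefetch(buffer_size=prefetch_buffer_size)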
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 4c466781f7..785b395707 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import parsing_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
@@ -214,18 +215,17 @@ def _maybe_shuffle_and_repeat(
return dataset
-def make_tf_record_dataset(
- file_pattern,
- batch_size,
- parser_fn=None,
- num_epochs=None,
- shuffle=True,
- shuffle_buffer_size=None,
- shuffle_seed=None,
- prefetch_buffer_size=None,
- num_parallel_reads=None,
- num_parallel_parser_calls=None,
- drop_final_batch=False):
+def make_tf_record_dataset(file_pattern,
+ batch_size,
+ parser_fn=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=None,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=None,
+ num_parallel_parser_calls=None,
+ drop_final_batch=False):
"""Reads and optionally parses TFRecord files into a dataset.
Provides common functionality such as batching, optional parsing, shuffling,
@@ -300,8 +300,6 @@ def make_tf_record_dataset(
parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
drop_remainder=drop_final_batch))
- if prefetch_buffer_size is None:
- prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE
if prefetch_buffer_size == 0:
return dataset
else:
@@ -323,7 +321,7 @@ def make_csv_dataset(
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
- prefetch_buffer_size=1,
+ prefetch_buffer_size=optimization.AUTOTUNE,
num_parallel_reads=1,
sloppy=False,
num_rows_for_inference=100,
@@ -386,9 +384,10 @@ def make_csv_dataset(
shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
ensures better shuffling, but increases memory usage and startup time.
shuffle_seed: Randomization seed to use for shuffling.
- prefetch_buffer_size: An int specifying the number of feature batches to
- prefetch for performance improvement. Recommended value is the number of
- batches consumed per training step.
+ prefetch_buffer_size: An int specifying the number of feature
+ batches to prefetch for performance improvement. Recommended value is the
+ number of batches consumed per training step. Defaults to auto-tune.
+
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
sloppy: If `True`, reading performance will be improved at
@@ -666,7 +665,7 @@ def make_batched_features_dataset(file_pattern,
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
- prefetch_buffer_size=1,
+ prefetch_buffer_size=optimization.AUTOTUNE,
reader_num_threads=1,
parser_num_threads=2,
sloppy_ordering=False,
@@ -739,7 +738,7 @@ def make_batched_features_dataset(file_pattern,
shuffle_seed: Randomization seed to use for shuffling.
prefetch_buffer_size: Number of feature batches to prefetch in order to
improve performance. Recommended value is the number of batches consumed
- per training step (default is 1).
+ per training step. Defaults to auto-tune.
reader_num_threads: Number of threads used to read `Example` records. If >1,
the results will be interleaved.
parser_num_threads: Number of threads to use for parsing `Example` tensors
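
For completeness, a usage sketch of `make_batched_features_dataset` under
the new default (the feature spec and file pattern are hypothetical):

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import readers

    features = {
        "image": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64),
    }

    # prefetch_buffer_size now defaults to optimization.AUTOTUNE, so this
    # call auto-tunes the prefetch buffer instead of prefetching one batch.
    dataset = readers.make_batched_features_dataset(
        file_pattern="/tmp/examples-*.tfrecord",  # hypothetical path
        batch_size=64,
        features=features)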