# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements a parallel data reader with queues and optional shuffling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes as tf_dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary
from tensorflow.python.training import input as tf_input
from tensorflow.python.training import queue_runner


class ParallelReader(io_ops.ReaderBase):
  """Reader class that uses multiple readers in parallel to improve speed.

  See ReaderBase for supported methods.
  """

  def __init__(self,
               reader_class,
               common_queue,
               num_readers=4,
               reader_kwargs=None):
    """ParallelReader creates num_readers instances of the reader_class.

    Each instance is created by calling the `reader_class` function passing
    the arguments specified in `reader_kwargs` as in:
      reader_class(**reader_kwargs)

    When you read from a ParallelReader with its `read()` method, you just
    dequeue examples from the `common_queue`.

    The readers will read different files in parallel, asynchronously enqueueing
    their output into `common_queue`. The `common_queue.dtypes` must be
    [tf.string, tf.string].

    Because each reader can read from a different file, the examples in the
    `common_queue` could be from different files. Due to the asynchronous
    reading there is no guarantee that all the readers will read the same
    number of examples.

    If the `common_queue` is a shuffling queue, then the examples are shuffled.

    Usage:
      common_queue = tf.RandomShuffleQueue(
          capacity=256,
          min_after_dequeue=128,
          dtypes=[tf.string, tf.string])
      p_reader = ParallelReader(tf.TFRecordReader, common_queue)

      common_queue = tf.FIFOQueue(
          capacity=256,
          dtypes=[tf.string, tf.string])
      p_reader = ParallelReader(tf.TFRecordReader, common_queue, num_readers=2)

    Args:
      reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader
      common_queue: a Queue to hold (key, value) pairs with `dtypes` equal to
        [tf.string, tf.string]. Must be one of the data_flow_ops.Queues
        instances, ex. `tf.FIFOQueue()`, `tf.RandomShuffleQueue()`, ...
      num_readers: an integer, number of instances of reader_class to create.
      reader_kwargs: an optional dict of kwargs to create the readers.

    Raises:
      TypeError: if `common_queue.dtypes` is not [tf.string, tf.string].
    """
    if len(common_queue.dtypes) != 2:
      raise TypeError('common_queue.dtypes must be [tf.string, tf.string]')
    for dtype in common_queue.dtypes:
      if not dtype.is_compatible_with(tf_dtypes.string):
        raise TypeError('common_queue.dtypes must be [tf.string, tf.string]')

    reader_kwargs = reader_kwargs or {}
    self._readers = [reader_class(**reader_kwargs) for _ in range(num_readers)]
    self._common_queue = common_queue

  @property
  def num_readers(self):
    return len(self._readers)

  @property
  def common_queue(self):
    return self._common_queue

  def read(self, queue, name=None):
    """Returns the next record (key, value pair) produced by the reader.

    The multiple reader instances are all configured to `read()` from the
    filenames listed in `queue` and enqueue their output into the `common_queue`
    passed to the constructor, and this method returns the next record dequeued
    from that `common_queue`.

    Readers dequeue a work unit from `queue` if necessary (e.g. when a
    reader needs to start reading from a new file since it has finished with
    the previous file).

    A queue runner for enqueuing into the `common_queue` is automatically
    added to the TF QueueRunners collection.

    Args:
      queue: A Queue or a mutable string Tensor representing a handle
        to a Queue, with string work items.
      name: A name for the operation (optional).

    Returns:
      The next record (i.e., a (key, value) pair) from the common_queue.
    """

    self._configure_readers_by(queue)
    return self._common_queue.dequeue(name=name)

  def read_up_to(self, queue, num_records, name=None):
    """Returns up to num_records (key, value pairs) produced by a reader.

    Will dequeue a work unit from queue if necessary (e.g., when the
    Reader needs to start reading from a new file since it has
    finished with the previous file).
    It may return fewer than num_records even before the last batch.

    **Note** This operation is not supported by all types of `common_queue`s.
    If a `common_queue` does not support `dequeue_up_to()`, then a
    `tf.errors.UnimplementedError` is raised.

    Args:
      queue: A Queue or a mutable string Tensor representing a handle
        to a Queue, with string work items.
      num_records: Number of records to read.
      name: A name for the operation (optional).

    Returns:
      A tuple of Tensors (keys, values) from common_queue.
      keys: A 1-D string Tensor.
      values: A 1-D string Tensor.
    """
    self._configure_readers_by(queue)
    return self._common_queue.dequeue_up_to(num_records, name)

  def _configure_readers_by(self, queue):
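    # Pair each reader's `read(queue)` op with an enqueue onto the common
    # queue; the single QueueRunner below drives all these enqueue ops from
    # parallel threads, interleaving records from the readers.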
    enqueue_ops = []
    for reader in self._readers:
      enqueue_ops.append(self._common_queue.enqueue(reader.read(queue)))

    queue_runner.add_queue_runner(
        queue_runner.QueueRunner(self._common_queue, enqueue_ops))

  def num_records_produced(self, name=None):
    """Returns the number of records this reader has produced.

    Args:
      name: A name for the operation (optional).

    Returns:
      An int64 Tensor.
    """
    num_records = [r.num_records_produced() for r in self._readers]
    return math_ops.add_n(num_records, name=name)

  def num_work_units_completed(self, name=None):
    """Returns the number of work units this reader has finished processing.

    Args:
      name: A name for the operation (optional).

    Returns:
      An int64 Tensor.
    """
    num_work_units = [r.num_work_units_completed() for r in self._readers]
    return math_ops.add_n(num_work_units, name=name)
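

# A minimal sketch (not used by this module) of wiring a ParallelReader by
# hand; `parallel_read()` below packages the same steps. The file pattern and
# queue sizes here are illustrative assumptions, not values from this file.
def _example_parallel_reader_usage():
  filename_queue = tf_input.string_input_producer(
      gfile.Glob('/tmp/data/train-*'),  # hypothetical file pattern
      shuffle=True)
  common_queue = data_flow_ops.RandomShuffleQueue(
      capacity=256,
      min_after_dequeue=128,
      dtypes=[tf_dtypes.string, tf_dtypes.string])
  p_reader = ParallelReader(
      io_ops.TFRecordReader, common_queue, num_readers=4)
  # read() registers the enqueueing QueueRunner and dequeues one record.
  return p_reader.read(filename_queue)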


def parallel_read(data_sources,
                  reader_class,
                  num_epochs=None,
                  num_readers=4,
                  reader_kwargs=None,
                  shuffle=True,
                  dtypes=None,
                  capacity=256,
                  min_after_dequeue=128,
                  seed=None,
                  scope=None):
  """Reads multiple records in parallel from data_sources using n readers.

  It uses a ParallelReader to read from multiple files in parallel using
  multiple readers created using `reader_class` with `reader_kwargs`.

  If `shuffle` is True, the common_queue will be a RandomShuffleQueue;
  otherwise it will be a FIFOQueue.

  Usage:
      data_sources = ['path_to/train*']
      key, value = parallel_read(data_sources,
                                 tf.TextLineReader,
                                 num_readers=4)

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader
    num_epochs: The number of times each data source is read. If left as None,
        the data will be cycled through indefinitely.
    num_readers: an integer, number of Readers to create.
    reader_kwargs: an optional dict of kwargs for the reader.
    shuffle: boolean, whether to shuffle the files and the records by using a
      RandomShuffleQueue as the common_queue.
    dtypes: A list of types. The length of dtypes must equal the number
        of elements in each record. If it is None it will default to
        [tf.string, tf.string] for (key, value).
    capacity: integer, capacity of the common_queue.
    min_after_dequeue: integer, minimum number of records in the common_queue
      after dequeue. Needed for a good shuffle.
    seed: A seed for RandomShuffleQueue.
    scope: Optional name scope for the ops.

  Returns:
    key, value: a tuple of keys and values from the data_source.
  """
  data_files = get_data_files(data_sources)
  with ops.name_scope(scope, 'parallel_read'):
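    # string_input_producer registers its own QueueRunner, which feeds the
    # file names (shuffled and epoch-limited as requested) to the readers.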
    filename_queue = tf_input.string_input_producer(
        data_files, num_epochs=num_epochs, shuffle=shuffle, seed=seed,
        name='filenames')
    dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
    if shuffle:
      common_queue = data_flow_ops.RandomShuffleQueue(
          capacity=capacity,
          min_after_dequeue=min_after_dequeue,
          dtypes=dtypes,
          seed=seed,
          name='common_queue')
    else:
      common_queue = data_flow_ops.FIFOQueue(
          capacity=capacity, dtypes=dtypes, name='common_queue')

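    # Export how full the common_queue is (0.0 = empty, 1.0 = full);
    # persistently low values suggest the readers are not keeping up with
    # the consumers.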
    summary.scalar('fraction_of_%d_full' % capacity,
                   math_ops.to_float(common_queue.size()) * (1. / capacity))

    return ParallelReader(
        reader_class,
        common_queue,
        num_readers=num_readers,
        reader_kwargs=reader_kwargs).read(filename_queue)
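

# A sketch (not used by this module) of consuming `parallel_read` in a TF1
# session. The file pattern and iteration count are illustrative assumptions.
def _example_parallel_read_loop():
  from tensorflow.python.client import session as tf_session
  from tensorflow.python.training import coordinator as tf_coordinator

  graph = ops.Graph()
  with graph.as_default():
    key, value = parallel_read(
        ['/tmp/data/train-*'],  # hypothetical file pattern
        io_ops.TFRecordReader,
        num_readers=4)
  with tf_session.Session(graph=graph) as sess:
    coord = tf_coordinator.Coordinator()
    # Start the filename-queue and common-queue runners registered above.
    threads = queue_runner.start_queue_runners(sess=sess, coord=coord)
    for _ in range(10):
      sess.run([key, value])
    coord.request_stop()
    coord.join(threads)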


def single_pass_read(data_sources, reader_class, reader_kwargs=None,
                     scope=None):
  """Reads sequentially the data_sources using the reader, doing a single pass.

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader.
    reader_kwargs: an optional dict of kwargs for the reader.
    scope: Optional name scope for the ops.

  Returns:
    key, value: a tuple of keys and values from the data_source.
  """
  data_files = get_data_files(data_sources)
  with ops.name_scope(scope, 'single_pass_read'):
    filename_queue = tf_input.string_input_producer(
        data_files, num_epochs=1, shuffle=False, capacity=1, name='filenames')
    reader_kwargs = reader_kwargs or {}
    return reader_class(**reader_kwargs).read(filename_queue)
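

# A minimal sketch (not used by this module) of draining `single_pass_read`
# until every record has been read exactly once. The file pattern is an
# illustrative assumption; `num_epochs=1` creates a local epoch counter, so
# local variables must be initialized first.
def _example_single_pass_read_loop():
  from tensorflow.python.client import session as tf_session
  from tensorflow.python.framework import errors
  from tensorflow.python.ops import variables
  from tensorflow.python.training import coordinator as tf_coordinator

  graph = ops.Graph()
  with graph.as_default():
    key, value = single_pass_read(
        '/tmp/data/eval-*',  # hypothetical file pattern
        io_ops.TFRecordReader)
    init_op = variables.local_variables_initializer()
  with tf_session.Session(graph=graph) as sess:
    sess.run(init_op)
    coord = tf_coordinator.Coordinator()
    threads = queue_runner.start_queue_runners(sess=sess, coord=coord)
    try:
      while True:
        sess.run([key, value])
    except errors.OutOfRangeError:
      pass  # the single pass over the files is complete
    finally:
      coord.request_stop()
      coord.join(threads)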


def get_data_files(data_sources):
  """Get data_files from data_sources.

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*

  Returns:
    a list of data_files.

  Raises:
    ValueError: if no data files are found.
  """
  if isinstance(data_sources, (list, tuple)):
    data_files = []
    for source in data_sources:
      data_files += get_data_files(source)
  else:
    if '*' in data_sources or '?' in data_sources or '[' in data_sources:
      data_files = gfile.Glob(data_sources)
    else:
      data_files = [data_sources]
  if not data_files:
    raise ValueError('No data files found in %s' % (data_sources,))
  return data_files
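

# For illustration (hypothetical paths): glob patterns are expanded via
# gfile.Glob, while plain paths and nested lists/tuples pass through, e.g.
#
#   get_data_files('/tmp/data/train-*')
#   # -> every file matching the pattern
#   get_data_files(['/tmp/a.tfrecord', '/tmp/b.tfrecord'])
#   # -> ['/tmp/a.tfrecord', '/tmp/b.tfrecord']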