# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implementations of different data feeders to provide data for TF trainer."""

# TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import math

import numpy as np
import six
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

# pylint: disable=g-multiple-import,g-bad-import-order
from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels
# pylint: enable=g-multiple-import,g-bad-import-order


def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
  """Returns shape for input and output of the data feeder."""
  if batch_size is None:
    batch_size = x_shape[0]
  elif batch_size <= 0:
    raise ValueError('Invalid batch_size %d.' % batch_size)
  x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
  input_shape = [batch_size] + x_shape
  if y_shape is None:
    return input_shape, None, batch_size
  y_shape = list(y_shape[1:]) if len(y_shape) > 1 else []
  # Skip first dimension if it is 1.
  if y_shape and y_shape[0] == 1:
    y_shape = y_shape[1:]
  if n_classes is not None and n_classes > 1:
    output_shape = [batch_size] + y_shape + [n_classes]
  else:
    output_shape = [batch_size] + y_shape
  return input_shape, output_shape, batch_size
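
# A minimal usage sketch of `_get_in_out_shape` (hypothetical shapes):
#
#   in_shape, out_shape, size = _get_in_out_shape(
#       x_shape=(100, 28, 28), y_shape=(100,), n_classes=10, batch_size=32)
#   # in_shape == [32, 28, 28]; out_shape == [32, 10]; size == 32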


def _data_type_filter(x, y):
  """Filter data types into acceptable format."""
  if HAS_DASK:
    x = extract_dask_data(x)
    if y is not None:
      y = extract_dask_labels(y)
  if HAS_PANDAS:
    x = extract_pandas_data(x)
    if y is not None:
      y = extract_pandas_labels(y)
  return x, y


def _is_iterable(x):
  """Returns true if `x` is an iterator, i.e. defines `next`/`__next__`."""
  return hasattr(x, 'next') or hasattr(x, '__next__')
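
# Note: despite its name, this detects *iterators* (objects with a `next` or
# `__next__` method), not general iterables. For example (hypothetical):
#
#   _is_iterable([1, 2, 3])        # False: lists have no `next` method.
#   _is_iterable(iter([1, 2, 3]))  # True: list iterators define `__next__`.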


def setup_train_data_feeder(
    x, y, n_classes, batch_size=None, shuffle=True, epochs=None):
  """Create data feeder, to sample inputs from dataset.

  If `x` and `y` are iterators, use `StreamingDataFeeder`.

  Args:
    x: numpy, pandas or Dask matrix or iterable.
    y: numpy, pandas or Dask array or iterable.
    n_classes: number of classes.
    batch_size: size to split data into parts. Must be >= 1.
    shuffle: Whether to shuffle the inputs.
    epochs: Number of epochs to run.

  Returns:
    DataFeeder object that returns training data.

  Raises:
    ValueError: if one of `x` and `y` is iterable and the other is not.
  """
  x, y = _data_type_filter(x, y)
  if HAS_DASK:
    # pylint: disable=g-import-not-at-top
    import dask.dataframe as dd
    if (isinstance(x, (dd.Series, dd.DataFrame)) and
        (y is None or isinstance(y, (dd.Series, dd.DataFrame)))):
      data_feeder_cls = DaskDataFeeder
    else:
      data_feeder_cls = DataFeeder
  else:
    data_feeder_cls = DataFeeder

  if _is_iterable(x):
    if y is not None and not _is_iterable(y):
      raise ValueError('Both x and y should be iterators for '
                       'streaming learning to work.')
    return StreamingDataFeeder(x, y, n_classes, batch_size)
  return data_feeder_cls(
      x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
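
# A minimal usage sketch (hypothetical in-memory data); numpy inputs select
# `DataFeeder`, while iterator inputs would select `StreamingDataFeeder`:
#
#   feeder = setup_train_data_feeder(
#       np.random.rand(100, 4), np.random.randint(0, 3, 100),
#       n_classes=3, batch_size=32)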


def _batch_data(x, batch_size=None):
  """Yields elements of iterator `x` batched into `np.matrix` chunks."""
  if (batch_size is not None) and (batch_size <= 0):
    raise ValueError('Invalid batch_size %d.' % batch_size)
  chunk = []
  for data in x:
    chunk.append(data)
    if (batch_size is not None) and (len(chunk) >= batch_size):
      yield np.matrix(chunk)
      chunk = []
  # Yield the remainder, but skip an empty trailing chunk.
  if chunk:
    yield np.matrix(chunk)
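
# Example (hypothetical data): five rows with batch_size=2 yield matrices of
# 2, 2 and 1 rows:
#
#   batches = list(_batch_data(iter([[1], [2], [3], [4], [5]]), batch_size=2))
#   # [matrix([[1], [2]]), matrix([[3], [4]]), matrix([[5]])]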


def setup_predict_data_feeder(x, batch_size=None):
  """Returns an iterable for feeding into predict step.

  Args:
    x: numpy, pandas, Dask array or iterable.
    batch_size: Size of batches to split data into.
      If `None`, returns one batch of full size.

  Returns:
    List or iterator of parts of data to predict on.

  Raises:
    ValueError: if `batch_size` <= 0.
  """
  if HAS_DASK:
    x = extract_dask_data(x)
  if HAS_PANDAS:
    x = extract_pandas_data(x)
  if _is_iterable(x):
    return _batch_data(x, batch_size)
  if len(x.shape) == 1:
    x = np.reshape(x, (-1, 1))
  if batch_size is not None:
    if batch_size <= 0:
      raise ValueError('Invalid batch_size %d.' % batch_size)
    n_batches = int(math.ceil(float(len(x)) / batch_size))
    return [x[i * batch_size:(i + 1) * batch_size] for i in xrange(n_batches)]
  return [x]
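
# Example (hypothetical shapes): a (5, 3) array with batch_size=2 is split
# into three parts of 2, 2 and 1 rows:
#
#   parts = setup_predict_data_feeder(np.zeros((5, 3)), batch_size=2)
#   # [p.shape for p in parts] == [(2, 3), (2, 3), (1, 3)]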


def setup_processor_data_feeder(x):
  """Sets up processor iterable.

  Args:
    x: numpy, pandas or iterable.

  Returns:
    Iterable of data to process.
  """
  if HAS_PANDAS:
    x = extract_pandas_matrix(x)
  return x


def check_array(array, dtype):
  """Checks array on dtype and converts it if different.

  Args:
    array: Input array.
    dtype: Expected dtype.

  Returns:
    Original array or converted.
  """
  # Skip the check if the array is an instance of another class, e.g.
  # h5py.Dataset, to avoid copying it and loading the whole dataset in memory.
  if isinstance(array, (np.ndarray, list)):
    array = np.array(array, dtype=dtype, order=None, copy=False)
  return array
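
# Example (hypothetical values): a Python list is materialized with the
# requested dtype, while e.g. an h5py.Dataset would pass through untouched:
#
#   arr = check_array([1, 2, 3], dtype=np.float32)
#   # arr.dtype == np.dtype('float32')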


class DataFeeder(object):
  """Data feeder is an example class to sample data for TF trainer.

  Parameters:
    x: feature Nd numpy matrix of shape [n_samples, n_features, ...].
    y: target vector, either floats for regression or class ids for
      classification. If a matrix, it is treated as a sequence of targets.
      Can be None for an unsupervised setting.
    n_classes: number of classes; 0 and 1 are considered regression; None
      passes through the input labels without one-hot conversion.
    batch_size: mini batch size to accumulate.
    random_state: numpy RandomState object to reproduce sampling.

  Attributes:
    x: input features.
    y: input target.
    n_classes: number of classes (if None, pass through indices without
      one-hot conversion).
    batch_size: mini batch size to accumulate.
    input_shape: shape of the input.
    output_shape: shape of the output.
    input_dtype: dtype of input.
    output_dtype: dtype of output.
  """

  def __init__(
      self, x, y, n_classes, batch_size=None, shuffle=True, random_state=None,
      epochs=None):
    x_dtype = np.int64 if x.dtype == np.int64 else np.float32
    y_dtype = (
        np.int64 if n_classes is not None and n_classes > 1 else np.float32)
    self.x = check_array(x, dtype=x_dtype)
    # self.n_classes is None means we're passing in raw target indices
    if n_classes is not None:
      self.y = (None if y is None else check_array(y, dtype=y_dtype))
    else:
      self.y = y
      if isinstance(self.y, list):
        self.y = np.array(y)
    self.n_classes = n_classes
    self.max_epochs = epochs
    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        self.x.shape, None if self.y is None else self.y.shape, n_classes,
        batch_size)
    # Input dtype matches dtype of x.
    self.input_dtype = x_dtype
    # self.n_classes is None means we're passing in raw target indices
    if n_classes is not None or y is None:
      self.output_dtype = np.float32
    else:
      self.output_dtype = self.y.dtype
    self.shuffle = shuffle
    self.random_state = np.random.RandomState(
        42) if random_state is None else random_state
    if self.shuffle:
      self.indices = self.random_state.permutation(self.x.shape[0])
    else:
      self.indices = np.array(range(self.x.shape[0]))
    self.offset = 0
    self.epoch = 0
    self._epoch_placeholder = None

  @property
  def batch_size(self):
    return self._batch_size

  def make_epoch_variable(self):
    """Adds a placeholder variable for the epoch to the graph.

    Returns:
      The epoch placeholder.
    """
    self._epoch_placeholder = array_ops.placeholder(dtypes.int32, [1],
                                                    name='epoch')
    return self._epoch_placeholder

  def input_builder(self):
    """Builds inputs in the graph.

    Returns:
      Two placeholders for inputs and outputs.
    """
    input_shape = [None] + self.input_shape[1:]
    self._input_placeholder = array_ops.placeholder(
        dtypes.as_dtype(self.input_dtype),
        input_shape,
        name='input')
    if self.output_shape is None:
      self._output_placeholder = None
    else:
      output_shape = [None] + self.output_shape[1:]
      self._output_placeholder = array_ops.placeholder(
          dtypes.as_dtype(self.output_dtype),
          output_shape,
          name='output')
    return self._input_placeholder, self._output_placeholder

  def set_placeholders(self, input_placeholder, output_placeholder):
    """Sets placeholders for this data feeder.

    Args:
      input_placeholder: Placeholder for `x` variable. Should match shape
        of the examples in the x dataset.
      output_placeholder: Placeholder for `y` variable. Should match
        shape of the examples in the y dataset. Can be None.
    """
    self._input_placeholder = input_placeholder
    self._output_placeholder = output_placeholder

  def get_feed_params(self):
    """Function returns a dict with data feed params while training.

    Returns:
      A dict with data feed params while training.
    """
    return {
        'epoch': self.epoch,
        'offset': self.offset,
        'batch_size': self._batch_size
    }

  def get_feed_dict_fn(self):
    """Returns a function that samples data into given placeholders.

    Returns:
      A function that, when called, samples a random subset of batch size
      from x and y.
    """
    def _feed_dict_fn():
      """Function that samples data into given placeholders."""
      if self.max_epochs is not None and self.epoch + 1 > self.max_epochs:
        raise StopIteration
      assert self._input_placeholder is not None
      feed_dict = {}
      if self._epoch_placeholder is not None:
        feed_dict[self._epoch_placeholder.name] = [self.epoch]

      # Take next batch of indices.
      end = min(self.x.shape[0], self.offset + self._batch_size)
      batch_indices = self.indices[self.offset:end]

      # Assign input features from random indices.
      inp = (
          np.array(self.x[batch_indices]).reshape((batch_indices.shape[0], 1))
          if len(self.x.shape) == 1 else self.x[batch_indices])
      feed_dict[self._input_placeholder.name] = inp

      # move offset and reset it if necessary
      self.offset += self._batch_size
      if self.offset >= self.x.shape[0]:
        self.indices = self.random_state.permutation(self.x.shape[0])
        self.offset = 0
        self.epoch += 1

      # return early if there are no labels
      if self._output_placeholder is None:
        return feed_dict

      # assign labels from random indices
      self.output_shape[0] = batch_indices.shape[0]
      out = np.zeros(self.output_shape, dtype=self.output_dtype)
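      # With n_classes > 1 the class ids below are one-hot encoded, e.g. for
      # n_classes=3 a target of 2 becomes [0., 0., 1.].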
      for i in xrange(out.shape[0]):
        sample = batch_indices[i]
        # self.n_classes is None means we're passing in raw target indices
        if self.n_classes is None:
          out[i] = self.y[sample]
        else:
          if self.n_classes > 1:
            if len(self.output_shape) == 2:
              out.itemset((i, int(self.y[sample])), 1.0)
            else:
              for idx, value in enumerate(self.y[sample]):
                out.itemset(tuple([i, idx, value]), 1.0)
          else:
            out[i] = self.y[sample]
      feed_dict[self._output_placeholder.name] = out

      return feed_dict

    return _feed_dict_fn
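
# A minimal end-to-end sketch of `DataFeeder` (hypothetical data; assumes an
# open tf.Session `sess` and a fetch tensor `loss` defined elsewhere):
#
#   feeder = DataFeeder(
#       np.random.rand(100, 4), np.random.randint(0, 3, 100),
#       n_classes=3, batch_size=32)
#   inp, out = feeder.input_builder()
#   feed_fn = feeder.get_feed_dict_fn()
#   sess.run(loss, feed_dict=feed_fn())  # Feed keys are placeholder names.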


class StreamingDataFeeder(DataFeeder):
  """Data feeder for TF trainer that reads data from iterator.

  Streaming data feeder allows to read data as it comes it from disk or
  somewhere else. It's custom to have this iterators rotate infinetly over
  the dataset, to allow control of how much to learn on the trainer side.

  Parameters:
    x: iterator that returns for each element, returns features.
    y: iterator that returns for each element, returns 1 or many classes /
       regression values.
    n_classes: indicator of how many classes the target has.
    batch_size: Mini batch size to accumulate.

  Attributes:
    x: input features.
    y: input target.
    n_classes: number of classes.
    batch_size: mini batch size to accumulate.
    input_shape: shape of the input.
    output_shape: shape of the output.
    input_dtype: dtype of input.
    output_dtype: dtype of output.
  """

  def __init__(self, x, y, n_classes, batch_size):
    # pylint: disable=invalid-name,super-init-not-called
    x_first_el = six.next(x)
    self.x = itertools.chain([x_first_el], x)
    if y is not None:
      y_first_el = six.next(y)
      self.y = itertools.chain([y_first_el], y)
    else:
      y_first_el = None
      self.y = None
    self.n_classes = n_classes
    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        [1] + list(x_first_el.shape),
        [1] + list(y_first_el.shape) if y is not None else None,
        n_classes,
        batch_size)
    self.input_dtype = x_first_el.dtype
    # Convert float64 to float32, as all the parameters in the model are
    # float32, and there are many benefits to using it in NNs.
    if self.input_dtype == np.float64:
      self.input_dtype = np.float32
    # Output types are floats, as required by both softmax and regression.
    if n_classes is not None and n_classes > 0:
      self.output_dtype = np.float32
    elif y is not None:
      if isinstance(y_first_el, list) or isinstance(y_first_el, np.ndarray):
        self.output_dtype = np.dtype(type(y_first_el[0]))
      else:
        self.output_dtype = np.dtype(type(y_first_el))

  def get_feed_params(self):
    """Function returns a dict with data feed params while training.

    Returns:
      A dict with data feed params while training.
    """
    return {'batch_size': self._batch_size}

  def get_feed_dict_fn(self):
    """Returns a function, that will sample data and provide it to placeholders.

    Returns:
      A function that when called samples a random subset of batch size
      from x and y.
    """
    self.stopped = False

    def _feed_dict_fn():
      """Samples data and provides it to placeholders.

      Returns:
        Dict of input and output tensors.
      """
      if self.stopped:
        raise StopIteration
      inp = np.zeros(self.input_shape, dtype=self.input_dtype)
      if self.y is not None:
        out = np.zeros(self.output_shape, dtype=self.output_dtype)
      for i in xrange(self._batch_size):
        # Stop when the input iterator ends.
        try:
          inp[i, :] = six.next(self.x)
        except StopIteration:
          self.stopped = True
          inp = inp[:i, :]
          if self.y is not None:
            out = out[:i]
          break

        if self.y is not None:
          y = six.next(self.y)
          if self.n_classes is not None and self.n_classes > 1:
            if len(self.output_shape) == 2:
              out.itemset((i, y), 1.0)
            else:
              for idx, value in enumerate(y):
                out.itemset(tuple([i, idx, value]), 1.0)
          else:
            out[i] = y
      if self.y is None:
        return {self._input_placeholder.name: inp}
      return {self._input_placeholder.name: inp,
              self._output_placeholder.name: out}

    return _feed_dict_fn
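
# A sketch of streaming from generators (hypothetical data): each call to the
# feed-dict function pulls the next `batch_size` elements from the iterators:
#
#   def x_iter():
#     while True:
#       yield np.random.rand(4)
#
#   def y_iter():
#     while True:
#       yield np.int64(np.random.randint(0, 3))
#
#   feeder = StreamingDataFeeder(x_iter(), y_iter(), n_classes=3,
#                                batch_size=32)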


class DaskDataFeeder(object):
  """Data feeder for that reads data from dask.Series and dask.DataFrame.

  Numpy arrays can be serialized to disk and it's possible to do random seeks
  into them. DaskDataFeeder will remove requirement to have full dataset in the
  memory and still do random seeks for sampling of batches.

  Parameters:
    x: dask.DataFrame (or dask.Series) of input features.
    y: dask.DataFrame (or dask.Series) of targets: one or many classes /
      regression values per element.
    n_classes: indicator of how many classes the target has.
    batch_size: Mini batch size to accumulate.
    random_state: random state for RNG. Note that it will mutate, so use an
      int value for this if you want consistent sized batches.

  Attributes:
    x: input features.
    y: input target.
    n_classes: number of classes.
    batch_size: mini batch size to accumulate.
    input_shape: shape of the input.
    output_shape: shape of the output.
    input_dtype: dtype of input.
    output_dtype: dtype of output.
  """
  def __init__(self, x, y, n_classes, batch_size, shuffle=True,
               random_state=None, epochs=None):
    # pylint: disable=invalid-name,super-init-not-called
    import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
    # TODO(terrytangyuan): check x and y dtypes in dask_io like pandas
    self.x = x
    self.y = y
    # save column names
    self.x_columns = list(x.columns)
    if isinstance(y.columns[0], str):
      self.y_columns = list(y.columns)
    else:
      # deal with cases where two DFs have overlapped default numeric colnames
      self.y_columns = len(self.x_columns) + 1
      self.y = self.y.rename(columns={y.columns[0]: self.y_columns})

    # TODO(terrytangyuan): deal with unsupervised cases
    # combine into a data frame
    self.df = dd.multi.concat([self.x, self.y], axis=1)
    self.n_classes = n_classes

    x_count = x.count().compute()[0]
    x_shape = (x_count, len(self.x.columns))
    y_shape = (x_count, len(self.y.columns))
    # TODO(terrytangyuan): Add support for shuffle and epochs.
    self.shuffle = shuffle
    self.epochs = epochs
    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        x_shape, y_shape, n_classes, batch_size)
    self.sample_fraction = self._batch_size / float(x_count)
    # TODO(ptucker,ipolosukhin): Remove this?
    # TODO(ipolosukhin): remove or restore.
    # self.x.dtypes[0], self.y.dtypes[self.y_columns]
    self.input_dtype, self.output_dtype = np.float32, np.float32
    if random_state is None:
      self.random_state = 66
    else:
      self.random_state = random_state

  def get_feed_params(self):
    """Function returns a dict with data feed params while training.

    Returns:
      A dict with data feed params while training.
    """
    return {'batch_size': self._batch_size}

  def get_feed_dict_fn(self, input_placeholder, output_placeholder):
    """Returns a function, that will sample data and provide it to placeholders.

    Args:
      input_placeholder: tf.Placeholder for input features mini batch.
      output_placeholder: tf.Placeholder for output targets.

    Returns:
      A function that, when called, samples a random subset of batch size
      from x and y.
    """
    def _feed_dict_fn():
      """Samples data and provides it to placeholders."""
      # TODO(ipolosukhin): option for with/without replacement (dev version of
      # dask)
      sample = self.df.random_split(
          [self.sample_fraction, 1 - self.sample_fraction],
          random_state=self.random_state)
      inp = extract_pandas_matrix(sample[0][self.x_columns].compute()).tolist()
      out = extract_pandas_matrix(sample[0][self.y_columns].compute())
      # convert to correct dtype
      inp = np.array(inp, dtype=self.input_dtype)
      # one-hot encode out for each class for cross entropy loss
      if HAS_PANDAS:
        import pandas as pd  # pylint: disable=g-import-not-at-top
        if not isinstance(out, pd.Series):
          out = out.flatten()
      out_max = self.y.max().compute().values[0]
      encoded_out = np.zeros((out.size, out_max + 1), dtype=self.output_dtype)
      encoded_out[np.arange(out.size), out] = 1
      return {input_placeholder.name: inp,
              output_placeholder.name: encoded_out}
    return _feed_dict_fn
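
# A sketch of the dask path (hypothetical data; requires pandas and dask):
#
#   import pandas as pd
#   import dask.dataframe as dd
#   x = dd.from_pandas(pd.DataFrame({'a': [0., 1.], 'b': [2., 3.]}),
#                      npartitions=1)
#   y = dd.from_pandas(pd.DataFrame({'label': [0, 1]}), npartitions=1)
#   feeder = DaskDataFeeder(x, y, n_classes=2, batch_size=1)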