diff options
Diffstat (limited to 'tensorflow/python/estimator/inputs/queues/feeding_functions.py')
-rw-r--r-- | tensorflow/python/estimator/inputs/queues/feeding_functions.py | 64 |
1 files changed, 61 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py index 9da2bce0f8..a6f5157680 100644 --- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py +++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py @@ -20,7 +20,9 @@ from __future__ import print_function import collections import random +import types as tp import numpy as np +import six from tensorflow.python.estimator.inputs.queues import feeding_queue_runner as fqr from tensorflow.python.framework import dtypes @@ -218,6 +220,54 @@ class _PandasFeedFn(object): return feed_dict +class _GeneratorFeedFn(object): + """Creates feed dictionaries from `Generator` of `dicts` of numpy arrays.""" + + def __init__(self, + placeholders, + generator, + batch_size, + random_start=False, + seed=None, + num_epochs=None): + first_sample = next(generator()) + if len(placeholders) != len(first_sample): + raise ValueError("Expected {} placeholders; got {}.".format( + len(first_sample), len(placeholders))) + self._keys = sorted(list(first_sample.keys())) + self._col_placeholders = placeholders + self._generator_function = generator + self._iterator = generator() + self._batch_size = batch_size + self._num_epochs = num_epochs + self._epoch = 0 + random.seed(seed) + + def __call__(self): + if self._num_epochs and self._epoch >= self._num_epochs: + raise errors.OutOfRangeError(None, None, + "Already emitted %s epochs." % self._epoch) + list_dict = {} + list_dict_size = 0 + while list_dict_size < self._batch_size: + try: + data_row = next(self._iterator) + except StopIteration: + self._epoch += 1 + self._iterator = self._generator_function() + data_row = next(self._iterator) + for index, key in enumerate(self._keys): + if key not in data_row.keys(): + raise KeyError("key mismatch between dicts emitted by GenFun" + "Expected {} keys; got {}".format( + self._keys, data_row.keys())) + list_dict.setdefault(self._col_placeholders[index], + list()).append(data_row[key]) + list_dict_size += 1 + feed_dict = {key: np.asarray(item) for key, item in list(list_dict.items())} + return feed_dict + + def _enqueue_data(data, capacity, shuffle=False, @@ -235,8 +285,9 @@ def _enqueue_data(data, numpy arrays, the first enqueued `Tensor` contains the row number. Args: - data: a numpy `ndarray`, `OrderedDict` of numpy arrays, or pandas - `DataFrame` that will be read into the queue. + data: a numpy `ndarray`, `OrderedDict` of numpy arrays, or a generator + yielding `dict`s of numpy arrays or pandas `DataFrame` that will be read + into the queue. capacity: the capacity of the queue. shuffle: whether or not to shuffle the rows of the array. min_after_dequeue: minimum number of elements that can remain in the queue @@ -254,7 +305,7 @@ def _enqueue_data(data, Raises: TypeError: `data` is not a Pandas `DataFrame`, an `OrderedDict` of numpy - arrays or a numpy `ndarray`. + arrays, a numpy `ndarray`, or a generator producing these. """ with ops.name_scope(name): if isinstance(data, np.ndarray): @@ -267,6 +318,13 @@ def _enqueue_data(data, ] queue_shapes = [()] + [col.shape[1:] for col in data.values()] get_feed_fn = _OrderedDictNumpyFeedFn + elif isinstance(data, tp.FunctionType): + x_first_el = six.next(data()) + x_first_keys = sorted(x_first_el.keys()) + x_first_values = [x_first_el[key] for key in x_first_keys] + types = [dtypes.as_dtype(col.dtype) for col in x_first_values] + queue_shapes = [col.shape for col in x_first_values] + get_feed_fn = _GeneratorFeedFn elif HAS_PANDAS and isinstance(data, pd.DataFrame): types = [ dtypes.as_dtype(dt) for dt in [data.index.dtype] + list(data.dtypes) |