aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/inputs/queues/feeding_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/inputs/queues/feeding_functions.py')
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions.py64
1 files changed, 61 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index 9da2bce0f8..a6f5157680 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -20,7 +20,9 @@ from __future__ import print_function
import collections
import random
+import types as tp
import numpy as np
+import six
from tensorflow.python.estimator.inputs.queues import feeding_queue_runner as fqr
from tensorflow.python.framework import dtypes
@@ -218,6 +220,54 @@ class _PandasFeedFn(object):
return feed_dict
class _GeneratorFeedFn(object):
  """Creates feed dictionaries from `Generator` of `dicts` of numpy arrays.

  The generator function is called to obtain a fresh iterator for every pass
  over the data. Each item it yields must be a `dict` holding the same keys as
  the first item; keys are paired with `placeholders` in sorted-key order.
  """

  def __init__(self,
               placeholders,
               generator,
               batch_size,
               random_start=False,
               seed=None,
               num_epochs=None):
    """Stores the feeding configuration and inspects the first sample.

    Args:
      placeholders: indexable collection of feed-dict keys, one per key of the
        dicts yielded by `generator`, ordered like the sorted dict keys.
      generator: zero-argument callable returning an iterator of `dict`s.
      batch_size: number of rows gathered into each feed dictionary.
      random_start: accepted for signature parity; not used by this class.
      seed: value passed to `random.seed`.
      num_epochs: number of passes over the generator after which `__call__`
        raises `OutOfRangeError`; a falsy value means no limit.

    Raises:
      ValueError: if the number of placeholders differs from the number of
        keys in the first sample.
    """
    # A throwaway iterator is used for the peek so that `self._iterator`
    # still starts from the first row.
    first_sample = next(generator())
    if len(placeholders) != len(first_sample):
      raise ValueError("Expected {} placeholders; got {}.".format(
          len(first_sample), len(placeholders)))
    self._keys = sorted(first_sample.keys())
    self._col_placeholders = placeholders
    self._generator_function = generator
    self._iterator = generator()
    self._batch_size = batch_size
    self._num_epochs = num_epochs
    self._epoch = 0
    random.seed(seed)

  def __call__(self):
    """Builds the next feed dictionary of `batch_size` rows.

    Returns:
      A `dict` mapping each placeholder to a numpy array that stacks the
      corresponding values of the rows collected for this batch.

    Raises:
      OutOfRangeError: if `num_epochs` passes were already completed before
        this call.
      KeyError: if a yielded dict lacks one of the keys of the first sample.
    """
    if self._num_epochs and self._epoch >= self._num_epochs:
      raise errors.OutOfRangeError(None, None,
                                   "Already emitted %s epochs." % self._epoch)
    list_dict = {}
    rows_collected = 0
    while rows_collected < self._batch_size:
      try:
        data_row = next(self._iterator)
      except StopIteration:
        # The generator is exhausted: count the epoch and start over. The
        # current batch is filled from the new pass, so the epoch limit is
        # only enforced at the start of the next call.
        self._epoch += 1
        self._iterator = self._generator_function()
        data_row = next(self._iterator)
      # Validate the whole row before appending anything from it.
      if any(key not in data_row for key in self._keys):
        raise KeyError("key mismatch between dicts emitted by GenFun "
                       "Expected {} keys; got {}".format(
                           self._keys, data_row.keys()))
      for index, key in enumerate(self._keys):
        list_dict.setdefault(self._col_placeholders[index],
                             []).append(data_row[key])
      # Count rows, not individual values, so that a batch always contains
      # `batch_size` rows regardless of how many keys each row has.
      rows_collected += 1
    return {key: np.asarray(item) for key, item in list_dict.items()}
+
+
def _enqueue_data(data,
capacity,
shuffle=False,
@@ -235,8 +285,9 @@ def _enqueue_data(data,
numpy arrays, the first enqueued `Tensor` contains the row number.
Args:
- data: a numpy `ndarray`, `OrderedDict` of numpy arrays, or pandas
- `DataFrame` that will be read into the queue.
+ data: a numpy `ndarray`, an `OrderedDict` of numpy arrays, a generator
+ function yielding `dict`s of numpy arrays, or a pandas `DataFrame` that
+ will be read into the queue.
capacity: the capacity of the queue.
shuffle: whether or not to shuffle the rows of the array.
min_after_dequeue: minimum number of elements that can remain in the queue
@@ -254,7 +305,7 @@ def _enqueue_data(data,
Raises:
TypeError: `data` is not a Pandas `DataFrame`, an `OrderedDict` of numpy
- arrays or a numpy `ndarray`.
+ arrays, a numpy `ndarray`, or a generator function yielding `dict`s of
+ numpy arrays.
"""
with ops.name_scope(name):
if isinstance(data, np.ndarray):
@@ -267,6 +318,13 @@ def _enqueue_data(data,
]
queue_shapes = [()] + [col.shape[1:] for col in data.values()]
get_feed_fn = _OrderedDictNumpyFeedFn
+ elif isinstance(data, tp.FunctionType):
+ x_first_el = six.next(data())
+ x_first_keys = sorted(x_first_el.keys())
+ x_first_values = [x_first_el[key] for key in x_first_keys]
+ types = [dtypes.as_dtype(col.dtype) for col in x_first_values]
+ queue_shapes = [col.shape for col in x_first_values]
+ get_feed_fn = _GeneratorFeedFn
elif HAS_PANDAS and isinstance(data, pd.DataFrame):
types = [
dtypes.as_dtype(dt) for dt in [data.index.dtype] + list(data.dtypes)