aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/inputs/queues/feeding_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/inputs/queues/feeding_functions.py')
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions.py123
1 files changed, 112 insertions, 11 deletions
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index 847b27b904..149425436a 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -46,6 +46,64 @@ except ImportError:
HAS_PANDAS = False
+def _fill_array(arr, seq, fillvalue=0):
+ """Recursively fills padded arr with elements from seq.
+
+ If lenght of seq is less then arr padded length, fillvalue used.
+ Args:
+ arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].
+ seq: Non-padded list of data sampels of shape
+ [batch_size, ..., padded_dim(None)]
+ fillvalue: Default fillvalue to use.
+ """
+ if arr.ndim == 1:
+ try:
+ len_ = len(seq)
+ except TypeError:
+ len_ = 0
+ arr[:len_] = seq
+ arr[len_:] = fillvalue
+ else:
+ for subarr, subseq in six.moves.zip_longest(arr, seq, fillvalue=()):
+ _fill_array(subarr, subseq, fillvalue)
+
+
+def _pad_if_needed(batch_key_item, fillvalue=0):
+ """ Returns padded batch.
+
+ Args:
+ batch_key_item: List of data samples of any type with shape
+ [batch_size, ..., padded_dim(None)].
+ fillvalue: Default fillvalue to use.
+
+ Returns:
+ Padded with zeros tensor of same type and shape
+ [batch_size, ..., max_padded_dim_len].
+
+ Raises:
+ ValueError if data samples have different shapes (except last padded dim).
+ """
+ shapes = [
+ seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item
+ ]
+ if not all(shapes[0] == x for x in shapes):
+ raise ValueError("Array shapes must match.")
+
+ last_length = [
+ seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item
+ ]
+ if all([x == last_length[0] for x in last_length]):
+ return batch_key_item
+
+ batch_size = len(batch_key_item)
+ max_sequence_length = max(last_length)
+ result_batch = np.zeros(
+ shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
+ dtype=batch_key_item[0].dtype)
+ _fill_array(result_batch, batch_key_item, fillvalue)
+ return result_batch
+
+
def _get_integer_indices_for_next_batch(
batch_indices_start, batch_size, epoch_end, array_length,
current_epoch, total_epochs):
@@ -229,7 +287,8 @@ class _GeneratorFeedFn(object):
batch_size,
random_start=False,
seed=None,
- num_epochs=None):
+ num_epochs=None,
+ pad_value=None):
first_sample = next(generator())
if len(placeholders) != len(first_sample):
raise ValueError("Expected {} placeholders; got {}.".format(
@@ -241,6 +300,7 @@ class _GeneratorFeedFn(object):
self._batch_size = batch_size
self._num_epochs = num_epochs
self._epoch = 0
+ self._pad_value = pad_value
random.seed(seed)
def __call__(self):
@@ -264,7 +324,17 @@ class _GeneratorFeedFn(object):
list_dict.setdefault(self._col_placeholders[index],
list()).append(data_row[key])
list_dict_size += 1
- feed_dict = {key: np.asarray(item) for key, item in list(list_dict.items())}
+
+ if self._pad_value is not None:
+ feed_dict = {
+ key: np.asarray(_pad_if_needed(item, self._pad_value))
+ for key, item in list(list_dict.items())
+ }
+ else:
+ feed_dict = {
+ key: np.asarray(item)
+ for key, item in list(list_dict.items())
+ }
return feed_dict
@@ -276,7 +346,8 @@ def _enqueue_data(data,
seed=None,
name="enqueue_input",
enqueue_size=1,
- num_epochs=None):
+ num_epochs=None,
+ pad_value=None):
"""Creates a queue filled from a numpy array or pandas `DataFrame`.
Returns a queue filled with the rows of the given (`OrderedDict` of) array
@@ -298,6 +369,7 @@ def _enqueue_data(data,
name: a scope name identifying the data.
enqueue_size: the number of rows to enqueue per step.
num_epochs: limit enqueuing to a specified number of epochs, if provided.
+ pad_value: default value for dynamic padding of data samples, if provided.
Returns:
A queue filled with the rows of the given (`OrderedDict` of) array or
@@ -306,6 +378,8 @@ def _enqueue_data(data,
Raises:
TypeError: `data` is not a Pandas `DataFrame`, an `OrderedDict` of numpy
arrays, a numpy `ndarray`, or a generator producing these.
+ NotImplementedError: padding and shuffling data at the same time.
+ NotImplementedError: padding usage with non generator data type.
"""
with ops.name_scope(name):
if isinstance(data, np.ndarray):
@@ -336,6 +410,14 @@ def _enqueue_data(data,
"data must be either a numpy array or pandas DataFrame if pandas is "
"installed; got {}".format(type(data).__name__))
+ pad_data = pad_value is not None
+ if pad_data and get_feed_fn is not _GeneratorFeedFn:
+ raise NotImplementedError(
+ "padding is only available with generator usage")
+ if shuffle and pad_data:
+ raise NotImplementedError(
+ "padding and shuffling data at the same time is not implemented")
+
# TODO(jamieas): TensorBoard warnings for all warnings below once available.
if num_threads > 1 and num_epochs is not None:
@@ -368,6 +450,13 @@ def _enqueue_data(data,
dtypes=types,
shapes=queue_shapes,
seed=seed)
+ elif pad_data:
+ min_after_dequeue = 0 # just for the summary text
+ queue_shapes = list(
+ map(lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
+ queue_shapes))
+ queue = data_flow_ops.PaddingFIFOQueue(
+ capacity, dtypes=types, shapes=queue_shapes)
else:
min_after_dequeue = 0 # just for the summary text
queue = data_flow_ops.FIFOQueue(
@@ -383,14 +472,26 @@ def _enqueue_data(data,
enqueue_ops.append(queue.enqueue_many(placeholders))
seed_i = None if seed is None else (i + 1) * seed
- feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs))
+
+ if not pad_data:
+ feed_fns.append(
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs))
+ else:
+ feed_fns.append(
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs,
+ pad_value=pad_value))
runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access
queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns)