diff options
Diffstat (limited to 'tensorflow/python/estimator/inputs/queues/feeding_functions.py')
-rw-r--r-- | tensorflow/python/estimator/inputs/queues/feeding_functions.py | 123 |
1 files changed, 112 insertions, 11 deletions
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py index 847b27b904..149425436a 100644 --- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py +++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py @@ -46,6 +46,64 @@ except ImportError: HAS_PANDAS = False +def _fill_array(arr, seq, fillvalue=0): + """Recursively fills padded arr with elements from seq. + + If lenght of seq is less then arr padded length, fillvalue used. + Args: + arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len]. + seq: Non-padded list of data sampels of shape + [batch_size, ..., padded_dim(None)] + fillvalue: Default fillvalue to use. + """ + if arr.ndim == 1: + try: + len_ = len(seq) + except TypeError: + len_ = 0 + arr[:len_] = seq + arr[len_:] = fillvalue + else: + for subarr, subseq in six.moves.zip_longest(arr, seq, fillvalue=()): + _fill_array(subarr, subseq, fillvalue) + + +def _pad_if_needed(batch_key_item, fillvalue=0): + """ Returns padded batch. + + Args: + batch_key_item: List of data samples of any type with shape + [batch_size, ..., padded_dim(None)]. + fillvalue: Default fillvalue to use. + + Returns: + Padded with zeros tensor of same type and shape + [batch_size, ..., max_padded_dim_len]. + + Raises: + ValueError if data samples have different shapes (except last padded dim). + """ + shapes = [ + seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item + ] + if not all(shapes[0] == x for x in shapes): + raise ValueError("Array shapes must match.") + + last_length = [ + seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item + ] + if all([x == last_length[0] for x in last_length]): + return batch_key_item + + batch_size = len(batch_key_item) + max_sequence_length = max(last_length) + result_batch = np.zeros( + shape=[batch_size] + list(shapes[0]) + [max_sequence_length], + dtype=batch_key_item[0].dtype) + _fill_array(result_batch, batch_key_item, fillvalue) + return result_batch + + def _get_integer_indices_for_next_batch( batch_indices_start, batch_size, epoch_end, array_length, current_epoch, total_epochs): @@ -229,7 +287,8 @@ class _GeneratorFeedFn(object): batch_size, random_start=False, seed=None, - num_epochs=None): + num_epochs=None, + pad_value=None): first_sample = next(generator()) if len(placeholders) != len(first_sample): raise ValueError("Expected {} placeholders; got {}.".format( @@ -241,6 +300,7 @@ class _GeneratorFeedFn(object): self._batch_size = batch_size self._num_epochs = num_epochs self._epoch = 0 + self._pad_value = pad_value random.seed(seed) def __call__(self): @@ -264,7 +324,17 @@ class _GeneratorFeedFn(object): list_dict.setdefault(self._col_placeholders[index], list()).append(data_row[key]) list_dict_size += 1 - feed_dict = {key: np.asarray(item) for key, item in list(list_dict.items())} + + if self._pad_value is not None: + feed_dict = { + key: np.asarray(_pad_if_needed(item, self._pad_value)) + for key, item in list(list_dict.items()) + } + else: + feed_dict = { + key: np.asarray(item) + for key, item in list(list_dict.items()) + } return feed_dict @@ -276,7 +346,8 @@ def _enqueue_data(data, seed=None, name="enqueue_input", enqueue_size=1, - num_epochs=None): + num_epochs=None, + pad_value=None): """Creates a queue filled from a numpy array or pandas `DataFrame`. Returns a queue filled with the rows of the given (`OrderedDict` of) array @@ -298,6 +369,7 @@ def _enqueue_data(data, name: a scope name identifying the data. enqueue_size: the number of rows to enqueue per step. num_epochs: limit enqueuing to a specified number of epochs, if provided. + pad_value: default value for dynamic padding of data samples, if provided. Returns: A queue filled with the rows of the given (`OrderedDict` of) array or @@ -306,6 +378,8 @@ def _enqueue_data(data, Raises: TypeError: `data` is not a Pandas `DataFrame`, an `OrderedDict` of numpy arrays, a numpy `ndarray`, or a generator producing these. + NotImplementedError: padding and shuffling data at the same time. + NotImplementedError: padding usage with non generator data type. """ with ops.name_scope(name): if isinstance(data, np.ndarray): @@ -336,6 +410,14 @@ def _enqueue_data(data, "data must be either a numpy array or pandas DataFrame if pandas is " "installed; got {}".format(type(data).__name__)) + pad_data = pad_value is not None + if pad_data and get_feed_fn is not _GeneratorFeedFn: + raise NotImplementedError( + "padding is only available with generator usage") + if shuffle and pad_data: + raise NotImplementedError( + "padding and shuffling data at the same time is not implemented") + # TODO(jamieas): TensorBoard warnings for all warnings below once available. if num_threads > 1 and num_epochs is not None: @@ -368,6 +450,13 @@ def _enqueue_data(data, dtypes=types, shapes=queue_shapes, seed=seed) + elif pad_data: + min_after_dequeue = 0 # just for the summary text + queue_shapes = list( + map(lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x, + queue_shapes)) + queue = data_flow_ops.PaddingFIFOQueue( + capacity, dtypes=types, shapes=queue_shapes) else: min_after_dequeue = 0 # just for the summary text queue = data_flow_ops.FIFOQueue( @@ -383,14 +472,26 @@ def _enqueue_data(data, enqueue_ops.append(queue.enqueue_many(placeholders)) seed_i = None if seed is None else (i + 1) * seed - feed_fns.append( - get_feed_fn( - placeholders, - data, - enqueue_size, - random_start=shuffle, - seed=seed_i, - num_epochs=num_epochs)) + + if not pad_data: + feed_fns.append( + get_feed_fn( + placeholders, + data, + enqueue_size, + random_start=shuffle, + seed=seed_i, + num_epochs=num_epochs)) + else: + feed_fns.append( + get_feed_fn( + placeholders, + data, + enqueue_size, + random_start=shuffle, + seed=seed_i, + num_epochs=num_epochs, + pad_value=pad_value)) runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns) |