aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/inputs/queues/feeding_functions.py
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2017-08-25 14:01:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-25 14:04:48 -0700
commit008910f1122d115a6d7430bfcc63cf4296c7467d (patch)
treee50199dcceed004cecc8510f9251f5e04734800f /tensorflow/python/estimator/inputs/queues/feeding_functions.py
parent005a88f6cc6e4e8c94a4f2d1980737855c4592f4 (diff)
Merge changes from github.
END_PUBLIC --- Commit b30ce4714 authored by James Qin<jamesqin@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Revamp CudnnRNN Saveables 1. Use a lossy way to save/restore cudnn biases during checkpointing. Cudnn uses 2 biases each gate for all RNNs while tf uses one. To allow cudnn checkpoints to be compatible with both Cudnn and platform-independent impls, previously both individual bias and summed biases each gate were stored. The new way only stores the bias sum for each gate, and split it half-half when restoring from a cudnn graph. Doing this does not cause problems since RNNs do not use weight-decay to regularize. 2. Use inheritance instead of branching * Split RNNParamsSaveable to 1 base class and 4 subclasses. * Extract common routines and only overwrite rnn-type-specific pieces in subclasses. PiperOrigin-RevId: 166413989 --- Commit ebc421daf authored by Alan Yee<alyee@ucsd.edu> Committed by Jonathan Hseu<vomjom@vomjom.net>: Update documentation for contrib (#12424) * Update __init__.py Remove ## for standardization of api docs * Create README.md Add README to define this directory's purpose * Update __init.py Markdown styling does not show up well in api docs * Update README.md Add short mention of describing what to deprecate * Update README.md Capitalize title * Update README.md Revert README change * Delete README.md --- Commit fd295394d authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Use latest version of nsync library, which now allows use of cmake on MacOS. PiperOrigin-RevId: 166411437 --- Commit 587d728e0 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Refactor reduce-precision-insertion filters, add several more options. In particular, this adds the ability to add reduce-precision operations after fusion nodes based on the contents of those fusion nodes, and the ability to filter operations based on the "op_name" metadata. PiperOrigin-RevId: 166408392 --- Commit 3142f8ef5 authored by Ali Yahya<alive@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Steps toward making ResourceVariables compatible with Eager. This change forces the value of the reuse flag in variable scopes to be tf.AUTO_REUSE when in Eager mode. This change also adds comprehensive Eager tests for ResourceVariable. PiperOrigin-RevId: 166408161 --- Commit b2ce45150 authored by Igor Ganichev<iga@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make Graph::IsValidNode public It can be reimplemented with existing public APIs, but instead of doing so, making this one public seems better. PiperOrigin-RevId: 166407897 --- Commit 0a2f40e92 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA::CPU] Fix HLO profiling in parallel CPU backend. PiperOrigin-RevId: 166400211 --- Commit c4a58e3fd authored by Yao Zhang<yaozhang@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Identify frame ids for all nodes in a graph. PiperOrigin-RevId: 166397615 --- Commit 989713f26 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 166294015 PiperOrigin-RevId: 166521502
Diffstat (limited to 'tensorflow/python/estimator/inputs/queues/feeding_functions.py')
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions.py73
1 files changed, 34 insertions, 39 deletions
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index 149425436a..d7fe4bbfa1 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -47,12 +47,13 @@ except ImportError:
def _fill_array(arr, seq, fillvalue=0):
- """Recursively fills padded arr with elements from seq.
-
+ """
+ Recursively fills padded arr with elements from seq.
If lenght of seq is less then arr padded length, fillvalue used.
+
Args:
arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].
- seq: Non-padded list of data sampels of shape
+ seq: Non-padded list of data sampels of shape
[batch_size, ..., padded_dim(None)]
fillvalue: Default fillvalue to use.
"""
@@ -83,23 +84,21 @@ def _pad_if_needed(batch_key_item, fillvalue=0):
Raises:
ValueError if data samples have different shapes (except last padded dim).
"""
- shapes = [
- seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item
- ]
+ shapes = [seq.shape[:-1] if len(seq.shape) > 0 else -1
+ for seq in batch_key_item]
if not all(shapes[0] == x for x in shapes):
raise ValueError("Array shapes must match.")
- last_length = [
- seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item
- ]
+ last_length = [seq.shape[-1] if len(seq.shape) > 0 else 0
+ for seq in batch_key_item]
if all([x == last_length[0] for x in last_length]):
return batch_key_item
batch_size = len(batch_key_item)
max_sequence_length = max(last_length)
result_batch = np.zeros(
- shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
- dtype=batch_key_item[0].dtype)
+ shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
+ dtype=batch_key_item[0].dtype)
_fill_array(result_batch, batch_key_item, fillvalue)
return result_batch
@@ -326,15 +325,11 @@ class _GeneratorFeedFn(object):
list_dict_size += 1
if self._pad_value is not None:
- feed_dict = {
- key: np.asarray(_pad_if_needed(item, self._pad_value))
- for key, item in list(list_dict.items())
- }
+ feed_dict = {key: np.asarray(_pad_if_needed(item, self._pad_value))
+ for key, item in list(list_dict.items())}
else:
- feed_dict = {
- key: np.asarray(item)
- for key, item in list(list_dict.items())
- }
+ feed_dict = {key: np.asarray(item)
+ for key, item in list(list_dict.items())}
return feed_dict
@@ -380,7 +375,7 @@ def _enqueue_data(data,
arrays, a numpy `ndarray`, or a generator producing these.
NotImplementedError: padding and shuffling data at the same time.
NotImplementedError: padding usage with non generator data type.
- """
+ """
with ops.name_scope(name):
if isinstance(data, np.ndarray):
types = [dtypes.int64, dtypes.as_dtype(data.dtype)]
@@ -452,11 +447,11 @@ def _enqueue_data(data,
seed=seed)
elif pad_data:
min_after_dequeue = 0 # just for the summary text
- queue_shapes = list(
- map(lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
- queue_shapes))
+ queue_shapes = list(map(
+ lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
+ queue_shapes))
queue = data_flow_ops.PaddingFIFOQueue(
- capacity, dtypes=types, shapes=queue_shapes)
+ capacity, dtypes=types, shapes=queue_shapes)
else:
min_after_dequeue = 0 # just for the summary text
queue = data_flow_ops.FIFOQueue(
@@ -475,23 +470,23 @@ def _enqueue_data(data,
if not pad_data:
feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs))
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs))
else:
feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs,
- pad_value=pad_value))
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs,
+ pad_value=pad_value))
runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access
queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns)