author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-05-22 08:11:01 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-05-22 08:14:48 -0700
commit    ac348ff2fe26f7b507f207767c0be2cb0765061c (patch)
tree      73d91bee87c93b0fb2afb2b5d1ed0ecdbb9794e6 /tensorflow/python/ops/rnn.py
parent    98a0bcf756bf7700664361a5ade778b99ebff1b1 (diff)
Create initial state with known batch size if the input batch size is known
PiperOrigin-RevId: 156738153
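
Background for the change: in TF 1.x, cell.zero_state(batch_size, dtype) produces state Tensors whose batch dimension is statically defined only when batch_size is a Python integer; passing the dynamic tf.shape(input)[1] leaves that dimension unknown and weakens downstream shape inference. A minimal sketch of the difference, not part of this commit (it assumes TF 1.x graph mode and tf.contrib.rnn.BasicLSTMCell):

import tensorflow as tf

cell = tf.contrib.rnn.BasicLSTMCell(num_units=4)

# Static batch size: the zero state has a fully defined shape.
static_state = cell.zero_state(batch_size=8, dtype=tf.float32)
print(static_state.c.shape)  # (8, 4)

# Dynamic batch size: the state's batch dimension stays unknown.
inputs = tf.placeholder(tf.float32, [20, None, 5])  # [max_time, batch, depth]
dynamic_state = cell.zero_state(batch_size=tf.shape(inputs)[1],
                                dtype=tf.float32)
print(dynamic_state.c.shape)  # (?, 4)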
Diffstat (limited to 'tensorflow/python/ops/rnn.py')
-rw-r--r--  tensorflow/python/ops/rnn.py | 36
1 file changed, 29 insertions(+), 7 deletions(-)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index a6fba046da..d4b7e04b84 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -68,6 +68,33 @@ def _transpose_batch_time(x):
   return x_t


+def _best_effort_input_batch_size(flat_input):
+  """Get static input batch size if available, with fallback to the dynamic one.
+
+  Args:
+    flat_input: An iterable of time major input Tensors of shape [max_time,
+      batch_size, ...]. All inputs should have compatible batch sizes.
+
+  Returns:
+    The batch size as a Python integer if available, or a scalar Tensor otherwise.
+
+  Raises:
+    ValueError: if there is any input with an invalid shape.
+  """
+  for input_ in flat_input:
+    shape = input_.shape
+    if shape.ndims is None:
+      continue
+    if shape.ndims < 2:
+      raise ValueError(
+          "Expected input tensor %s to have rank at least 2" % input_)
+    batch_size = shape[1].value
+    if batch_size is not None:
+      return batch_size
+  # Fallback to the dynamic batch size of the first input.
+  return array_ops.shape(flat_input[0])[1]
+
+
 def _infer_state_dtype(explicit_dtype, state):
   """Infer the dtype of an RNN state.
@@ -525,12 +552,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
   with vs.variable_scope(scope or "rnn") as varscope:
     if varscope.caching_device is None:
       varscope.set_caching_device(lambda op: op.device)
-    input_shape = tuple(array_ops.shape(input_) for input_ in flat_input)
-    batch_size = input_shape[0][1]
-
-    for input_ in input_shape:
-      if input_[1].get_shape() != batch_size.get_shape():
-        raise ValueError("All inputs should have the same batch size")
+    batch_size = _best_effort_input_batch_size(flat_input)

     if initial_state is not None:
       state = initial_state
@@ -623,7 +645,7 @@ def _dynamic_rnn_loop(cell,
   # Construct an initial output
   input_shape = array_ops.shape(flat_input[0])
   time_steps = input_shape[0]
-  batch_size = input_shape[1]
+  batch_size = _best_effort_input_batch_size(flat_input)

   inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
                            for input_ in flat_input)
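
For reference, a hedged sketch of what the new helper computes for time-major inputs; the placeholder shapes below are illustrative and not part of the commit:

import tensorflow as tf

# Batch dimension (axis 1 in time-major layout) statically known: the helper
# can return the Python integer 32 without adding any ops to the graph.
x = tf.placeholder(tf.float32, [None, 32, 5])  # [max_time, batch_size, depth]
print(x.shape[1].value)  # 32 -- what _best_effort_input_batch_size returns

# Batch dimension unknown: the static lookup yields None, so the helper falls
# back to the dynamic scalar Tensor tf.shape(y)[1], resolved at run time.
y = tf.placeholder(tf.float32, [None, None, 5])
print(y.shape[1].value)  # None -- fallback to tf.shape(y)[1]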