diff options
author | 2017-05-22 08:11:01 -0700 | |
---|---|---|
committer | 2017-05-22 08:14:48 -0700 | |
commit | ac348ff2fe26f7b507f207767c0be2cb0765061c (patch) | |
tree | 73d91bee87c93b0fb2afb2b5d1ed0ecdbb9794e6 /tensorflow/python/ops/rnn.py | |
parent | 98a0bcf756bf7700664361a5ade778b99ebff1b1 (diff) |
Create initial state with known batch size if the input batch size is known
PiperOrigin-RevId: 156738153
Diffstat (limited to 'tensorflow/python/ops/rnn.py')
-rw-r--r-- | tensorflow/python/ops/rnn.py | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index a6fba046da..d4b7e04b84 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -68,6 +68,33 @@ def _transpose_batch_time(x): return x_t +def _best_effort_input_batch_size(flat_input): + """Get static input batch size if available, with fallback to the dynamic one. + + Args: + flat_input: An iterable of time major input Tensors of shape [max_time, + batch_size, ...]. All inputs should have compatible batch sizes. + + Returns: + The batch size in Python integer if available, or a scalar Tensor otherwise. + + Raises: + ValueError: if there is any input with an invalid shape. + """ + for input_ in flat_input: + shape = input_.shape + if shape.ndims is None: + continue + if shape.ndims < 2: + raise ValueError( + "Expected input tensor %s to have rank at least 2" % input_) + batch_size = shape[1].value + if batch_size is not None: + return batch_size + # Fallback to the dynamic batch size of the first input. + return array_ops.shape(flat_input[0])[1] + + def _infer_state_dtype(explicit_dtype, state): """Infer the dtype of an RNN state. @@ -525,12 +552,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, with vs.variable_scope(scope or "rnn") as varscope: if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) - input_shape = tuple(array_ops.shape(input_) for input_ in flat_input) - batch_size = input_shape[0][1] - - for input_ in input_shape: - if input_[1].get_shape() != batch_size.get_shape(): - raise ValueError("All inputs should have the same batch size") + batch_size = _best_effort_input_batch_size(flat_input) if initial_state is not None: state = initial_state @@ -623,7 +645,7 @@ def _dynamic_rnn_loop(cell, # Construct an initial output input_shape = array_ops.shape(flat_input[0]) time_steps = input_shape[0] - batch_size = input_shape[1] + batch_size = _best_effort_input_batch_size(flat_input) inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3) for input_ in flat_input) |