author     Eugene Brevdo <ebrevdo@gmail.com>  2016-03-07 17:11:19 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-03-08 17:17:34 -0800
commit     e6388816e30386182967e149a2e59e68c7a73f26 (patch)
tree       fe80a0b1d310cc60bcd9b61a3bfdab156c0e85d8
parent     9ba61973849a9ef79104c8295886049634a193b4 (diff)
Remove the requirement for sequence_length input to dynamic_rnn.
Change: 116605455
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py | 17
-rw-r--r--  tensorflow/python/ops/rnn.py               | 54
2 files changed, 45 insertions(+), 26 deletions(-)
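
With this change, `sequence_length` becomes an optional keyword argument of `dynamic_rnn`. The snippet below is a minimal usage sketch, assuming a 2016-era build where `dynamic_rnn` and `rnn_cell` are exposed under `tf.nn`; the cell choice, shapes, and scope names are illustrative and not taken from this commit.

import tensorflow as tf

batch_size, max_time, depth = 4, 7, 10
inputs = tf.placeholder(tf.float32, [batch_size, max_time, depth])
cell = tf.nn.rnn_cell.GRUCell(depth)  # illustrative cell choice

# sequence_length may now be omitted; every batch row is then unrolled
# for the full max_time steps.
outputs, final_state = tf.nn.dynamic_rnn(
    cell, inputs, dtype=tf.float32, scope="no_lengths")

# Passing per-row lengths still zeroes outputs and copies state through
# once a row's sequence has ended.
seq_len = tf.placeholder(tf.int32, [batch_size])
outputs_masked, state_masked = tf.nn.dynamic_rnn(
    cell, inputs, sequence_length=seq_len, dtype=tf.float32,
    scope="with_lengths")
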
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index d8f6376651..7d040c221e 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -560,7 +560,7 @@ class LSTMTest(tf.test.TestCase):
for out0, out1 in zip(outputs0_values, outputs1_values):
self.assertAllEqual(out0, out1)
- def _testDynamicEquivalentToStaticRNN(self, use_gpu):
+ def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
time_steps = 8
num_units = 3
num_proj = 4
@@ -569,7 +569,10 @@ class LSTMTest(tf.test.TestCase):
input_values = np.random.randn(time_steps, batch_size, input_size)
- sequence_length = np.random.randint(0, time_steps, size=batch_size)
+ if use_sequence_length:
+ sequence_length = np.random.randint(0, time_steps, size=batch_size)
+ else:
+ sequence_length = None
########### Step 1: Run static graph and generate readouts
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
@@ -744,8 +747,14 @@ class LSTMTest(tf.test.TestCase):
self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=True)
def testDynamicEquivalentToStaticRNN(self):
- self._testDynamicEquivalentToStaticRNN(use_gpu=False)
- self._testDynamicEquivalentToStaticRNN(use_gpu=True)
+ self._testDynamicEquivalentToStaticRNN(
+ use_gpu=False, use_sequence_length=False)
+ self._testDynamicEquivalentToStaticRNN(
+ use_gpu=True, use_sequence_length=False)
+ self._testDynamicEquivalentToStaticRNN(
+ use_gpu=False, use_sequence_length=True)
+ self._testDynamicEquivalentToStaticRNN(
+ use_gpu=True, use_sequence_length=True)
class BidirectionalRNNTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index f5e4c4317b..e010e371a7 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -344,9 +344,9 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
return outputs
-def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
- parallel_iterations=None, swap_memory=False, time_major=False,
- scope=None):
+def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
+ dtype=None, parallel_iterations=None, swap_memory=False,
+ time_major=False, scope=None):
"""Creates a recurrent neural network specified by RNNCell "cell".
This function is functionally identical to the function `rnn` above, but
@@ -367,9 +367,9 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
`[batch_size, max_time, cell.input_size]`.
If time_major == True, this must be a tensor of shape:
`[max_time, batch_size, cell.input_size]`.
- sequence_length: An int32/int64 vector (tensor) size [batch_size].
+ sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
initial_state: (optional) An initial state for the RNN. This must be
- a tensor of appropriate type and shape [batch_size x cell.state_size].
+ a tensor of appropriate type and shape `[batch_size x cell.state_size]`.
dtype: (optional) The data type for the initial state. Required if
initial_state is not provided.
parallel_iterations: (Default: 32). The number of iterations to run in
@@ -413,8 +413,10 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,D) => (T,B,D)
parallel_iterations = parallel_iterations or 32
- sequence_length = math_ops.to_int32(sequence_length)
- sequence_length = array_ops.identity(sequence_length, name="sequence_length")
+ if sequence_length is not None:
+ sequence_length = math_ops.to_int32(sequence_length)
+ sequence_length = array_ops.identity( # Just to find it in the graph.
+ sequence_length, name="sequence_length")
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
@@ -440,15 +442,16 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
["Expected shape for Tensor %s is " % x.name,
packed_shape, " but saw shape: ", x_shape])
- # Perform some shape validation
- with ops.control_dependencies(
- [_assert_has_shape(sequence_length, [batch_size])]):
- sequence_length = array_ops.identity(sequence_length, name="CheckSeqLen")
+ if sequence_length is not None:
+ # Perform some shape validation
+ with ops.control_dependencies(
+ [_assert_has_shape(sequence_length, [batch_size])]):
+ sequence_length = array_ops.identity(
+ sequence_length, name="CheckSeqLen")
(outputs, final_state) = _dynamic_rnn_loop(
- cell, inputs, state, sequence_length,
- parallel_iterations=parallel_iterations,
- swap_memory=swap_memory)
+ cell, inputs, state, parallel_iterations=parallel_iterations,
+ swap_memory=swap_memory, sequence_length=sequence_length)
# Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
# If we are performing batch-major calculations, transpose output back
@@ -459,17 +462,18 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
return (outputs, final_state)
-def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
- parallel_iterations, swap_memory):
+def _dynamic_rnn_loop(
+ cell, inputs, initial_state, parallel_iterations, swap_memory,
+ sequence_length=None):
"""Internal implementation of Dynamic RNN.
Args:
cell: An instance of RNNCell.
inputs: A `Tensor` of shape [time, batch_size, depth].
initial_state: A `Tensor` of shape [batch_size, depth].
- sequence_length: An `int32` `Tensor` of shape [batch_size].
parallel_iterations: Positive Python int.
swap_memory: A Python boolean
+ sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
Returns:
Tuple (final_outputs, final_state).
@@ -500,8 +504,9 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
# Prepare dynamic conditional copying of state & output
zero_output = array_ops.zeros(
array_ops.pack([batch_size, cell.output_size]), inputs.dtype)
- min_sequence_length = math_ops.reduce_min(sequence_length)
- max_sequence_length = math_ops.reduce_max(sequence_length)
+ if sequence_length is not None:
+ min_sequence_length = math_ops.reduce_min(sequence_length)
+ max_sequence_length = math_ops.reduce_max(sequence_length)
time = array_ops.constant(0, dtype=dtypes.int32, name="time")
@@ -534,9 +539,14 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
# Restore some shape information
input_t.set_shape([const_batch_size, const_depth])
- (output, new_state) = _rnn_step(
- time, sequence_length, min_sequence_length, max_sequence_length,
- zero_output, state, lambda: cell(input_t, state))
+ call_cell = lambda: cell(input_t, state)
+
+ if sequence_length is not None:
+ (output, new_state) = _rnn_step(
+ time, sequence_length, min_sequence_length, max_sequence_length,
+ zero_output, state, call_cell)
+ else:
+ (output, new_state) = call_cell()
output_ta_t = output_ta_t.write(time, output)
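
The conditional added to the loop body above means the per-step masking in `_rnn_step` is only applied when lengths are supplied; otherwise the cell is called directly. A simplified NumPy sketch of that per-time-step behaviour follows (names and shapes are illustrative; the real `_rnn_step` additionally uses `min_sequence_length`/`max_sequence_length` to short-circuit entire steps):

import numpy as np

def step(time, state, zero_output, call_cell, sequence_length=None):
  # One time step, mirroring the branch added in _dynamic_rnn_loop.
  cell_output, cell_state = call_cell()
  if sequence_length is None:
    # No lengths given: use the cell's result for every batch row.
    return cell_output, cell_state
  # Rows whose sequence has already ended emit zeros and keep their state.
  finished = time >= sequence_length                 # shape [batch]
  output = np.where(finished[:, None], zero_output, cell_output)
  new_state = np.where(finished[:, None], state, cell_state)
  return output, new_state

# Example: batch of 3, depth 2, at time step 1; row 1 has length 1, so it
# is already finished: its output is zero and its state is carried through.
state = np.ones((3, 2))
zero_output = np.zeros((3, 2))
call_cell = lambda: (np.full((3, 2), 5.0), np.full((3, 2), 7.0))
out, new_state = step(1, state, zero_output, call_cell,
                      sequence_length=np.array([3, 1, 2]))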