author    | Eugene Brevdo <ebrevdo@gmail.com>               | 2016-03-07 17:11:19 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-03-08 17:17:34 -0800
commit    | e6388816e30386182967e149a2e59e68c7a73f26 (patch)
tree      | fe80a0b1d310cc60bcd9b61a3bfdab156c0e85d8
parent    | 9ba61973849a9ef79104c8295886049634a193b4 (diff)
Remove the requirement for sequence_length input to dynamic_rnn.
Change: 116605455
-rw-r--r-- | tensorflow/python/kernel_tests/rnn_test.py | 17
-rw-r--r-- | tensorflow/python/ops/rnn.py               | 54
2 files changed, 45 insertions, 26 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index d8f6376651..7d040c221e 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -560,7 +560,7 @@ class LSTMTest(tf.test.TestCase):
     for out0, out1 in zip(outputs0_values, outputs1_values):
       self.assertAllEqual(out0, out1)
 
-  def _testDynamicEquivalentToStaticRNN(self, use_gpu):
+  def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
     time_steps = 8
     num_units = 3
     num_proj = 4
@@ -569,7 +569,10 @@ class LSTMTest(tf.test.TestCase):
 
     input_values = np.random.randn(time_steps, batch_size, input_size)
-    sequence_length = np.random.randint(0, time_steps, size=batch_size)
+    if use_sequence_length:
+      sequence_length = np.random.randint(0, time_steps, size=batch_size)
+    else:
+      sequence_length = None
 
     ########### Step 1: Run static graph and generate readouts
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
@@ -744,8 +747,14 @@ class LSTMTest(tf.test.TestCase):
     self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=True)
 
   def testDynamicEquivalentToStaticRNN(self):
-    self._testDynamicEquivalentToStaticRNN(use_gpu=False)
-    self._testDynamicEquivalentToStaticRNN(use_gpu=True)
+    self._testDynamicEquivalentToStaticRNN(
+        use_gpu=False, use_sequence_length=False)
+    self._testDynamicEquivalentToStaticRNN(
+        use_gpu=True, use_sequence_length=False)
+    self._testDynamicEquivalentToStaticRNN(
+        use_gpu=False, use_sequence_length=True)
+    self._testDynamicEquivalentToStaticRNN(
+        use_gpu=True, use_sequence_length=True)
 
 
 class BidirectionalRNNTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index f5e4c4317b..e010e371a7 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -344,9 +344,9 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
   return outputs
 
 
-def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
-                parallel_iterations=None, swap_memory=False, time_major=False,
-                scope=None):
+def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
+                dtype=None, parallel_iterations=None, swap_memory=False,
+                time_major=False, scope=None):
   """Creates a recurrent neural network specified by RNNCell "cell".
 
   This function is functionally identical to the function `rnn` above, but
@@ -367,9 +367,9 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
       `[batch_size, max_time, cell.input_size]`.
       If time_major == True, this must be a tensor of shape:
       `[max_time, batch_size, cell.input_size]`.
-    sequence_length: An int32/int64 vector (tensor) size [batch_size].
+    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
     initial_state: (optional) An initial state for the RNN. This must be
-      a tensor of appropriate type and shape [batch_size x cell.state_size].
+      a tensor of appropriate type and shape `[batch_size x cell.state_size]`.
     dtype: (optional) The data type for the initial state. Required if
       initial_state is not provided.
     parallel_iterations: (Default: 32). The number of iterations to run in
@@ -413,8 +413,10 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
     inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,D) => (T,B,D)
 
   parallel_iterations = parallel_iterations or 32
-  sequence_length = math_ops.to_int32(sequence_length)
-  sequence_length = array_ops.identity(sequence_length, name="sequence_length")
+  if sequence_length is not None:
+    sequence_length = math_ops.to_int32(sequence_length)
+    sequence_length = array_ops.identity(  # Just to find it in the graph.
+        sequence_length, name="sequence_length")
 
   # Create a new scope in which the caching device is either
   # determined by the parent scope, or is set to place the cached
@@ -440,15 +442,16 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
           ["Expected shape for Tensor %s is " % x.name,
            packed_shape, " but saw shape: ", x_shape])
 
-    # Perform some shape validation
-    with ops.control_dependencies(
-        [_assert_has_shape(sequence_length, [batch_size])]):
-      sequence_length = array_ops.identity(sequence_length, name="CheckSeqLen")
+    if sequence_length is not None:
+      # Perform some shape validation
+      with ops.control_dependencies(
+          [_assert_has_shape(sequence_length, [batch_size])]):
+        sequence_length = array_ops.identity(
+            sequence_length, name="CheckSeqLen")
 
     (outputs, final_state) = _dynamic_rnn_loop(
-        cell, inputs, state, sequence_length,
-        parallel_iterations=parallel_iterations,
-        swap_memory=swap_memory)
+        cell, inputs, state, parallel_iterations=parallel_iterations,
+        swap_memory=swap_memory, sequence_length=sequence_length)
 
     # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
     # If we are performing batch-major calculations, transpose output back
@@ -459,17 +462,18 @@ def dynamic_rnn(cell, inputs, sequence_length, initial_state=None, dtype=None,
     return (outputs, final_state)
 
 
-def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
-                      parallel_iterations, swap_memory):
+def _dynamic_rnn_loop(
+    cell, inputs, initial_state, parallel_iterations, swap_memory,
+    sequence_length=None):
   """Internal implementation of Dynamic RNN.
 
   Args:
     cell: An instance of RNNCell.
     inputs: A `Tensor` of shape [time, batch_size, depth].
    initial_state: A `Tensor` of shape [batch_size, depth].
-    sequence_length: An `int32` `Tensor` of shape [batch_size].
     parallel_iterations: Positive Python int.
     swap_memory: A Python boolean
+    sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
 
   Returns:
     Tuple (final_outputs, final_state).
@@ -500,8 +504,9 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
   # Prepare dynamic conditional copying of state & output
   zero_output = array_ops.zeros(
       array_ops.pack([batch_size, cell.output_size]), inputs.dtype)
-  min_sequence_length = math_ops.reduce_min(sequence_length)
-  max_sequence_length = math_ops.reduce_max(sequence_length)
+  if sequence_length is not None:
+    min_sequence_length = math_ops.reduce_min(sequence_length)
+    max_sequence_length = math_ops.reduce_max(sequence_length)
 
   time = array_ops.constant(0, dtype=dtypes.int32, name="time")
 
@@ -534,9 +539,14 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
     # Restore some shape information
     input_t.set_shape([const_batch_size, const_depth])
 
-    (output, new_state) = _rnn_step(
-        time, sequence_length, min_sequence_length, max_sequence_length,
-        zero_output, state, lambda: cell(input_t, state))
+    call_cell = lambda: cell(input_t, state)
+
+    if sequence_length is not None:
+      (output, new_state) = _rnn_step(
+          time, sequence_length, min_sequence_length, max_sequence_length,
+          zero_output, state, call_cell)
+    else:
+      (output, new_state) = call_cell()
 
     output_ta_t = output_ta_t.write(time, output)
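
For reference, a minimal usage sketch of what this change enables. It is not part of the commit; the shapes, the cell construction, and the assumption that `dynamic_rnn` and the cell classes are reachable under `tf.nn` reflect the graph-mode API of TensorFlow releases from around this time and are illustrative only:

```python
import numpy as np
import tensorflow as tf

batch_size, max_time, input_size, num_units = 4, 8, 5, 3

inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])
cell = tf.nn.rnn_cell.LSTMCell(num_units, input_size)

# After this change, sequence_length may simply be omitted; the cell is then
# run for all max_time steps of every batch entry.
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

# Passing sequence_length keeps the previous behavior: outputs past each
# entry's true length are zeroed and the state stops updating there.
seq_len = tf.constant([8, 5, 3, 1], dtype=tf.int32)
outputs_masked, _ = tf.nn.dynamic_rnn(
    cell, inputs, sequence_length=seq_len, dtype=tf.float32,
    scope="masked_rnn")  # separate scope, so a second set of weights is built

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    feed = {inputs: np.random.randn(batch_size, max_time,
                                    input_size).astype(np.float32)}
    out, out_masked = sess.run([outputs, outputs_masked], feed)
    print(out.shape)          # (4, 8, 3)
    print(out_masked[1, 5:])  # zeros: steps at or past seq_len[1] == 5
```

Before this commit, the first `dynamic_rnn` call above would have failed, since `sequence_length` was a required positional argument.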