-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py | 14
-rw-r--r--  tensorflow/python/ops/rnn.py               | 42
2 files changed, 41 insertions(+), 15 deletions(-)
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index be59ac08c2..82c432922c 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -957,7 +957,7 @@ def graph_creation_static_vs_dynamic_rnn_benchmark(max_time):
 
   def _timer(sess, ops):
     # Warm in
-    for _ in range(5):
+    for _ in range(2):
       sess.run(ops)
 
     # Timing run
@@ -1100,24 +1100,24 @@ def rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
 
 def main(_):
   print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
   print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
-  for max_time in (1, 25, 50, 100, 200):
+  for max_time in (1, 25, 50):
     graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
 
   print("Calculation: Static Unroll with Dynamic Flow LSTM "
         "vs. Dynamic Unroll LSTM")
   print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
         "\t dt(dynamic)/dt(static)")
-  for use_gpu in (False, True):
-    for batch_size in (256, 512):
-      for max_time in (50, 100):
-        for num_units in (512, 256, 128):
+  for batch_size in (256,):
+    for max_time in (50,):
+      for num_units in (512, 256, 128):
+        for use_gpu in (False, True):
          static_vs_dynamic_rnn_benchmark(
              batch_size, max_time, num_units, use_gpu)
 
   print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
   print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
   for batch_size in (256, 512):
-    for max_time in (50, 100):
+    for max_time in (100,):
      for num_units in (512, 256, 128):
        dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index cddaaeee7b..6b50a30205 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -182,11 +182,11 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
 
 def _rnn_step(
     time, sequence_length, min_sequence_length, max_sequence_length,
-    zero_output, state, call_cell):
+    zero_output, state, call_cell, skip_conditionals=False):
   """Calculate one step of a dynamic RNN minibatch.
 
   Returns an (output, state) pair conditioned on the sequence_lengths.
-  The pseudocode is something like:
+  When skip_conditionals=False, the pseudocode is something like:
 
   if t >= max_sequence_length:
     return (zero_output, state)
@@ -216,6 +216,10 @@ def _rnn_step(
     call_cell: lambda returning tuple of (new_output, new_state) where
       new_output is a `Tensor` matrix of shape [batch_size, output_size]
       new_state is a `Tensor` matrix of shape [batch_size, state_size]
+    skip_conditionals: Python bool, whether to skip using the conditional
+      calculations. This is useful for dynamic_rnn, where the input tensor
+      matches max_sequence_length, and using conditionals just slows
+      everything down.
 
   Returns:
     A tuple of (final_output, final_state) as given by the pseudocode above:
@@ -225,8 +229,15 @@ def _rnn_step(
   # Step 1: determine whether we need to call_cell or not
   empty_update = lambda: (zero_output, state)
   state_shape = state.get_shape()
-  output, new_state = control_flow_ops.cond(
-      time < max_sequence_length, call_cell, empty_update)
+
+  if skip_conditionals:
+    # Skip using conditionals: calculate the RNN step at all time
+    # steps. This is faster for dynamic_rnn, where the time steps
+    # should cap out at max_sequence_length anyway.
+    output, new_state = call_cell()
+  else:
+    output, new_state = control_flow_ops.cond(
+        time < max_sequence_length, call_cell, empty_update)
 
   # Step 2: determine whether we need to copy through state and/or outputs
   existing_output_state = lambda: (output, new_state)
@@ -239,8 +250,17 @@ def _rnn_step(
     return (math_ops.select(copy_cond, zero_output, output),
             math_ops.select(copy_cond, state, new_state))
 
-  (output, state) = control_flow_ops.cond(
-      time < min_sequence_length, existing_output_state, copy_through)
+  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
+  # but benefits from removing cond() and its gradient. We should
+  # profile with and without this switch here.
+  if skip_conditionals:
+    # Skip using conditionals: perform the selective copy at all time
+    # steps. This is usually faster.
+    (output, state) = copy_through()
+  else:
+    (output, state) = control_flow_ops.cond(
+        time < min_sequence_length, existing_output_state, copy_through)
+
   output.set_shape(zero_output.get_shape())
   state.set_shape(state_shape)
   return (output, state)
@@ -549,8 +569,14 @@ def _dynamic_rnn_loop(
 
     if sequence_length is not None:
       (output, new_state) = _rnn_step(
-          time, sequence_length, min_sequence_length, max_sequence_length,
-          zero_output, state, call_cell)
+          time=time,
+          sequence_length=sequence_length,
+          min_sequence_length=min_sequence_length,
+          max_sequence_length=max_sequence_length,
+          zero_output=zero_output,
+          state=state,
+          call_cell=call_cell,
+          skip_conditionals=True)
     else:
       (output, new_state) = call_cell()
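
The rnn.py change above rests on an equivalence: when the input tensor is padded out to the batch's longest sequence, unconditionally running the cell at every step and then applying the selective copy (the math_ops.select calls in copy_through) yields the same outputs and final state as wrapping the computation in control_flow_ops.cond, so the cond only adds graph and gradient overhead. A minimal NumPy sketch of that equivalence (toy names are hypothetical; this is not TensorFlow code):

import numpy as np

def toy_cell(x, h):
  # Toy "RNN cell": output and new state are the same tanh update.
  h_new = np.tanh(x + h)
  return h_new, h_new

def step_masked(t, x_t, h, seq_len):
  # skip_conditionals=True analogue: always compute, then mask per example
  # (the counterpart of copy_cond / math_ops.select in copy_through).
  out, h_new = toy_cell(x_t, h)
  done = t >= seq_len
  return np.where(done, 0.0, out), np.where(done, h, h_new)

def step_conditional(t, x_t, h, seq_len):
  # skip_conditionals=False analogue: skip the cell entirely once every
  # sequence in the batch is finished (time >= max_sequence_length).
  if t >= seq_len.max():
    return np.zeros_like(h), h
  return step_masked(t, x_t, h, seq_len)

def run(step_fn, inputs, seq_len):
  h = np.zeros(inputs.shape[0])
  outs = []
  for t in range(inputs.shape[1]):
    out, h = step_fn(t, inputs[:, t], h, seq_len)
    outs.append(out)
  return np.stack(outs, axis=1), h

inputs = np.random.RandomState(0).randn(3, 7)  # [batch_size, max_time]
seq_len = np.array([5, 3, 1])
masked = run(step_masked, inputs, seq_len)
cond = run(step_conditional, inputs, seq_len)
assert all(np.allclose(a, b) for a, b in zip(masked, cond))

When max_time equals the largest sequence length, as the new docstring notes is the case for dynamic_rnn, the t >= seq_len.max() branch never fires at all, which is exactly why _dynamic_rnn_loop now passes skip_conditionals=True.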
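Separately, the rnn_test.py tweaks only shrink the benchmark grid and the warm-up count; the measurement pattern of _timer itself is unchanged. A standalone analogue of that pattern (plain Python, hypothetical helper name): a few warm-up runs to amortize one-time costs, then an averaged timing loop.

import time

def timer(run_once, warmup=2, iters=20):
  # Warm up: the first runs pay one-time setup costs.
  for _ in range(warmup):
    run_once()
  # Timing run: average the steady-state cost per iteration.
  start = time.time()
  for _ in range(iters):
    run_once()
  return (time.time() - start) / iters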