diff options
author | 2016-03-15 09:52:29 -0800 | |
---|---|---|
committer | 2016-03-15 11:46:46 -0700 | |
commit | b10c65c2a00b9ce7972ab4300f49e83d0f849fd4 (patch) | |
tree | 1b2560e61e11e9df7f7b91ff0d7c90dbfbc2dffe | |
parent | c9c341e4833612392ddd4fca3fd4343afaed47a4 (diff) |
Remove use of conitional pass-through from dynamic_rnn in favor of just
calling select() at each step. conditionals add a bunch of extra ops
that slow things down; and generally the size of the input tensor into
dynamic_rnn matches max_sequence_length so they provide no benefit.
Before change, benchmarks:
Calculation: Static Unroll with Dynamic Flow LSTM vs. Dynamic Unroll LSTM
batch max_t units gpu dt(static) dt(dynamic) dt(dynamic)/dt(static)
256 50 512 False 1.795002 1.774248 0.988437
256 50 512 True 0.186828 0.200752 1.074525
256 50 256 False 0.597320 0.750226 1.255986
256 50 256 True 0.082047 0.091411 1.114130
256 50 128 False 0.250596 0.238233 0.950666
256 50 128 True 0.056480 0.063086 1.116960
After change, benchmarks:
Calculation: Static Unroll with Dynamic Flow LSTM vs. Dynamic Unroll LSTM batch max_t units gpu dt(static) dt(dynamic) dt(dynamic)/dt(static)
256 50 512 False 1.723348 1.763019 1.023020
256 50 512 True 0.186794 0.196334 1.051072
256 50 256 False 0.644540 0.704506 1.093036
256 50 256 True 0.082274 0.087785 1.066985
256 50 128 False 0.241971 0.234559 0.969368
256 50 128 True 0.056356 0.059771 1.060611
Basically expect a more significant decrease in GPU step time when the matrices are smaller.
Change: 117254684
-rw-r--r-- | tensorflow/python/kernel_tests/rnn_test.py | 14 | ||||
-rw-r--r-- | tensorflow/python/ops/rnn.py | 42 |
2 files changed, 41 insertions, 15 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index be59ac08c2..82c432922c 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -957,7 +957,7 @@ def graph_creation_static_vs_dynamic_rnn_benchmark(max_time): def _timer(sess, ops): # Warm in - for _ in range(5): + for _ in range(2): sess.run(ops) # Timing run @@ -1100,24 +1100,24 @@ def rnn_long_sequence_benchmark(batch_size, seqlen, num_units, def main(_): print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM") print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)") - for max_time in (1, 25, 50, 100, 200): + for max_time in (1, 25, 50): graph_creation_static_vs_dynamic_rnn_benchmark(max_time) print("Calculation: Static Unroll with Dynamic Flow LSTM " "vs. Dynamic Unroll LSTM") print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) " "\t dt(dynamic)/dt(static)") - for use_gpu in (False, True): - for batch_size in (256, 512): - for max_time in (50, 100): - for num_units in (512, 256, 128): + for batch_size in (256,): + for max_time in (50,): + for num_units in (512, 256, 128): + for use_gpu in (False, True): static_vs_dynamic_rnn_benchmark( batch_size, max_time, num_units, use_gpu) print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap") print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap") for batch_size in (256, 512): - for max_time in (50, 100): + for max_time in (100,): for num_units in (512, 256, 128): dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index cddaaeee7b..6b50a30205 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -182,11 +182,11 @@ def state_saving_rnn(cell, inputs, state_saver, state_name, def _rnn_step( time, sequence_length, min_sequence_length, max_sequence_length, - zero_output, state, call_cell): + zero_output, state, call_cell, skip_conditionals=False): """Calculate one step of a dynamic RNN minibatch. Returns an (output, state) pair conditioned on the sequence_lengths. - The pseudocode is something like: + When skip_conditionals=False, the pseudocode is something like: if t >= max_sequence_length: return (zero_output, state) @@ -216,6 +216,10 @@ def _rnn_step( call_cell: lambda returning tuple of (new_output, new_state) where new_output is a `Tensor` matrix of shape [batch_size, output_size] new_state is a `Tensor` matrix of shape [batch_size, state_size] + skip_conditionals: Python bool, whether to skip using the conditional + calculations. This is useful for dynamic_rnn, where the input tensor + matches max_sequence_length, and using conditionals just slows + everything down. Returns: A tuple of (final_output, final_state) as given by the pseudocode above: @@ -225,8 +229,15 @@ def _rnn_step( # Step 1: determine whether we need to call_cell or not empty_update = lambda: (zero_output, state) state_shape = state.get_shape() - output, new_state = control_flow_ops.cond( - time < max_sequence_length, call_cell, empty_update) + + if skip_conditionals: + # Skip using conditionals: calculate the RNN step at all time + # steps. This is faster for dynamic_rnn, where the time steps + # should cap out at max_sequence_length anyway. + output, new_state = call_cell() + else: + output, new_state = control_flow_ops.cond( + time < max_sequence_length, call_cell, empty_update) # Step 2: determine whether we need to copy through state and/or outputs existing_output_state = lambda: (output, new_state) @@ -239,8 +250,17 @@ def _rnn_step( return (math_ops.select(copy_cond, zero_output, output), math_ops.select(copy_cond, state, new_state)) - (output, state) = control_flow_ops.cond( - time < min_sequence_length, existing_output_state, copy_through) + # TODO(ebrevdo): skipping these conditionals may cause a slowdown, + # but benefits from removing cond() and its gradient. We should + # profile with and without this switch here. + if skip_conditionals: + # Skip using conditionals: perform the selective copy at all time + # steps. This is usually faster. + (output, state) = copy_through() + else: + (output, state) = control_flow_ops.cond( + time < min_sequence_length, existing_output_state, copy_through) + output.set_shape(zero_output.get_shape()) state.set_shape(state_shape) return (output, state) @@ -549,8 +569,14 @@ def _dynamic_rnn_loop( if sequence_length is not None: (output, new_state) = _rnn_step( - time, sequence_length, min_sequence_length, max_sequence_length, - zero_output, state, call_cell) + time=time, + sequence_length=sequence_length, + min_sequence_length=min_sequence_length, + max_sequence_length=max_sequence_length, + zero_output=zero_output, + state=state, + call_cell=call_cell, + skip_conditionals=True) else: (output, new_state) = call_cell() |