author     Eugene Brevdo <ebrevdo@gmail.com>            2016-03-15 09:52:29 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-03-15 11:46:46 -0700
commit     b10c65c2a00b9ce7972ab4300f49e83d0f849fd4 (patch)
tree       1b2560e61e11e9df7f7b91ff0d7c90dbfbc2dffe
parent     c9c341e4833612392ddd4fca3fd4343afaed47a4 (diff)
Remove use of conditional pass-through from dynamic_rnn in favor of just
calling select() at each step.

Conditionals add a bunch of extra ops that slow things down, and generally
the size of the input tensor into dynamic_rnn matches max_sequence_length,
so they provide no benefit.

Before change, benchmarks:

Calculation: Static Unroll with Dynamic Flow LSTM vs. Dynamic Unroll LSTM
batch  max_t  units  gpu    dt(static)  dt(dynamic)  dt(dynamic)/dt(static)
256    50     512    False  1.795002    1.774248     0.988437
256    50     512    True   0.186828    0.200752     1.074525
256    50     256    False  0.597320    0.750226     1.255986
256    50     256    True   0.082047    0.091411     1.114130
256    50     128    False  0.250596    0.238233     0.950666
256    50     128    True   0.056480    0.063086     1.116960

After change, benchmarks:

Calculation: Static Unroll with Dynamic Flow LSTM vs. Dynamic Unroll LSTM
batch  max_t  units  gpu    dt(static)  dt(dynamic)  dt(dynamic)/dt(static)
256    50     512    False  1.723348    1.763019     1.023020
256    50     512    True   0.186794    0.196334     1.051072
256    50     256    False  0.644540    0.704506     1.093036
256    50     256    True   0.082274    0.087785     1.066985
256    50     128    False  0.241971    0.234559     0.969368
256    50     128    True   0.056356    0.059771     1.060611

The main expected effect is a more significant decrease in GPU step time
when the matrices are smaller.

Change: 117254684
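The idea, in a hedged sketch: instead of wrapping each time step in a cond()
that decides whether to run the cell at all, the cell runs unconditionally and
rows whose sequence has already ended are overwritten with a per-batch select.
The snippet below is a minimal NumPy illustration of that copy-through masking,
not the TensorFlow implementation; the names (masked_step, cell_output, etc.)
are illustrative only.

    import numpy as np

    def masked_step(time, sequence_length, zero_output, prev_state,
                    cell_output, cell_state):
        """One step of copy-through masking (NumPy stand-in for select):
        examples past their sequence length keep zero output and carry
        their previous state forward unchanged."""
        # copy_cond[i] is True when example i has already finished.
        copy_cond = (time >= sequence_length)[:, None]  # broadcast over features
        output = np.where(copy_cond, zero_output, cell_output)
        state = np.where(copy_cond, prev_state, cell_state)
        return output, state

    # Usage: batch of 3 at step index 2, with lengths (1, 3, 2).
    batch, dim = 3, 4
    seq_len = np.array([1, 3, 2])
    zero_out = np.zeros((batch, dim))
    prev_state = np.ones((batch, dim))
    cell_out = np.full((batch, dim), 7.0)
    cell_state = np.full((batch, dim), 9.0)
    out, st = masked_step(2, seq_len, zero_out, prev_state, cell_out, cell_state)
    # Rows 0 and 2 are past their lengths: output stays zero, state stays 1.0;
    # row 1 takes the freshly computed cell output and state.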
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py  14
-rw-r--r--  tensorflow/python/ops/rnn.py                 42
2 files changed, 41 insertions, 15 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index be59ac08c2..82c432922c 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -957,7 +957,7 @@ def graph_creation_static_vs_dynamic_rnn_benchmark(max_time):
def _timer(sess, ops):
# Warm in
- for _ in range(5):
+ for _ in range(2):
sess.run(ops)
# Timing run
@@ -1100,24 +1100,24 @@ def rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
def main(_):
print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
- for max_time in (1, 25, 50, 100, 200):
+ for max_time in (1, 25, 50):
graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
print("Calculation: Static Unroll with Dynamic Flow LSTM "
"vs. Dynamic Unroll LSTM")
print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
"\t dt(dynamic)/dt(static)")
- for use_gpu in (False, True):
- for batch_size in (256, 512):
- for max_time in (50, 100):
- for num_units in (512, 256, 128):
+ for batch_size in (256,):
+ for max_time in (50,):
+ for num_units in (512, 256, 128):
+ for use_gpu in (False, True):
static_vs_dynamic_rnn_benchmark(
batch_size, max_time, num_units, use_gpu)
print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
for batch_size in (256, 512):
- for max_time in (50, 100):
+ for max_time in (100,):
for num_units in (512, 256, 128):
dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index cddaaeee7b..6b50a30205 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -182,11 +182,11 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
- zero_output, state, call_cell):
+ zero_output, state, call_cell, skip_conditionals=False):
"""Calculate one step of a dynamic RNN minibatch.
Returns an (output, state) pair conditioned on the sequence_lengths.
- The pseudocode is something like:
+ When skip_conditionals=False, the pseudocode is something like:
if t >= max_sequence_length:
return (zero_output, state)
@@ -216,6 +216,10 @@ def _rnn_step(
call_cell: lambda returning tuple of (new_output, new_state) where
new_output is a `Tensor` matrix of shape [batch_size, output_size]
new_state is a `Tensor` matrix of shape [batch_size, state_size]
+ skip_conditionals: Python bool, whether to skip using the conditional
+ calculations. This is useful for dynamic_rnn, where the input tensor
+ matches max_sequence_length, and using conditionals just slows
+ everything down.
Returns:
A tuple of (final_output, final_state) as given by the pseudocode above:
@@ -225,8 +229,15 @@ def _rnn_step(
# Step 1: determine whether we need to call_cell or not
empty_update = lambda: (zero_output, state)
state_shape = state.get_shape()
- output, new_state = control_flow_ops.cond(
- time < max_sequence_length, call_cell, empty_update)
+
+ if skip_conditionals:
+ # Skip using conditionals: calculate the RNN step at all time
+ # steps. This is faster for dynamic_rnn, where the time steps
+ # should cap out at max_sequence_length anyway.
+ output, new_state = call_cell()
+ else:
+ output, new_state = control_flow_ops.cond(
+ time < max_sequence_length, call_cell, empty_update)
# Step 2: determine whether we need to copy through state and/or outputs
existing_output_state = lambda: (output, new_state)
@@ -239,8 +250,17 @@ def _rnn_step(
return (math_ops.select(copy_cond, zero_output, output),
math_ops.select(copy_cond, state, new_state))
- (output, state) = control_flow_ops.cond(
- time < min_sequence_length, existing_output_state, copy_through)
+ # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
+ # but benefits from removing cond() and its gradient. We should
+ # profile with and without this switch here.
+ if skip_conditionals:
+ # Skip using conditionals: perform the selective copy at all time
+ # steps. This is usually faster.
+ (output, state) = copy_through()
+ else:
+ (output, state) = control_flow_ops.cond(
+ time < min_sequence_length, existing_output_state, copy_through)
+
output.set_shape(zero_output.get_shape())
state.set_shape(state_shape)
return (output, state)
@@ -549,8 +569,14 @@ def _dynamic_rnn_loop(
if sequence_length is not None:
(output, new_state) = _rnn_step(
- time, sequence_length, min_sequence_length, max_sequence_length,
- zero_output, state, call_cell)
+ time=time,
+ sequence_length=sequence_length,
+ min_sequence_length=min_sequence_length,
+ max_sequence_length=max_sequence_length,
+ zero_output=zero_output,
+ state=state,
+ call_cell=call_cell,
+ skip_conditionals=True)
else:
(output, new_state) = call_cell()
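To see why skip_conditionals=True is safe at the loop level, here is a hedged
toy sketch of the whole iteration in NumPy: a trivial cumulative-sum "cell"
stands in for the LSTM, and run_rnn and its arguments are illustrative names,
not the TensorFlow implementation. Running the cell at every step and then
applying the select-based copy-through gives the same final output and state
as the branching version; it just drops the per-step cond() and its gradient.

    import numpy as np

    def run_rnn(inputs, sequence_length, skip_conditionals):
        """Toy loop mirroring the shape of _rnn_step with a cumulative-sum cell."""
        batch, max_time, dim = inputs.shape
        state = np.zeros((batch, dim))
        zero_output = np.zeros((batch, dim))
        output = zero_output
        max_len = int(sequence_length.max())
        for t in range(max_time):
            call_cell = lambda: (state + inputs[:, t], state + inputs[:, t])
            if skip_conditionals or t < max_len:
                new_output, new_state = call_cell()   # run the cell every step
            else:
                new_output, new_state = zero_output, state
            # Per-example copy-through via select/where instead of cond().
            copy = (t >= sequence_length)[:, None]
            output = np.where(copy, zero_output, new_output)
            state = np.where(copy, state, new_state)
        return output, state

    x = np.random.randn(3, 5, 2)
    lens = np.array([2, 5, 3])
    a = run_rnn(x, lens, skip_conditionals=False)
    b = run_rnn(x, lens, skip_conditionals=True)
    # Both variants agree; the second just avoids the per-step branch,
    # which is the saving the commit message's benchmarks measure.
    assert np.allclose(a[0], b[0]) and np.allclose(a[1], b[1])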