-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py  14
-rw-r--r--  tensorflow/python/ops/rnn.py                 42
2 files changed, 41 insertions(+), 15 deletions(-)
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index be59ac08c2..82c432922c 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -957,7 +957,7 @@ def graph_creation_static_vs_dynamic_rnn_benchmark(max_time):
 
   def _timer(sess, ops):
     # Warm in
-    for _ in range(5):
+    for _ in range(2):
      sess.run(ops)
 
     # Timing run
@@ -1100,24 +1100,24 @@ def rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
 
 def main(_):
   print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
   print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
-  for max_time in (1, 25, 50, 100, 200):
+  for max_time in (1, 25, 50):
     graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
 
   print("Calculation: Static Unroll with Dynamic Flow LSTM "
         "vs. Dynamic Unroll LSTM")
   print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
         "\t dt(dynamic)/dt(static)")
-  for use_gpu in (False, True):
-    for batch_size in (256, 512):
-      for max_time in (50, 100):
-        for num_units in (512, 256, 128):
+  for batch_size in (256,):
+    for max_time in (50,):
+      for num_units in (512, 256, 128):
+        for use_gpu in (False, True):
           static_vs_dynamic_rnn_benchmark(
               batch_size, max_time, num_units, use_gpu)
 
   print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
   print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
   for batch_size in (256, 512):
-    for max_time in (50, 100):
+    for max_time in (100,):
       for num_units in (512, 256, 128):
         dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units)
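
For context: the _timer change above only shortens the warm-in phase; _timer itself follows the standard warm-up-then-measure benchmarking pattern. A minimal standalone sketch of that pattern (plain Python, illustrative names, no TensorFlow session):

import time

def timed(fn, warmup=2, runs=20):
  # Warm in: prime caches and any lazy initialization before measuring.
  for _ in range(warmup):
    fn()
  # Timing run: average wall-clock time over `runs` calls.
  start = time.time()
  for _ in range(runs):
    fn()
  return (time.time() - start) / runs
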
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index cddaaeee7b..6b50a30205 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -182,11 +182,11 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
 
 def _rnn_step(
     time, sequence_length, min_sequence_length, max_sequence_length,
-    zero_output, state, call_cell):
+    zero_output, state, call_cell, skip_conditionals=False):
   """Calculate one step of a dynamic RNN minibatch.
 
   Returns an (output, state) pair conditioned on the sequence_lengths.
-  The pseudocode is something like:
+  When skip_conditionals=False, the pseudocode is something like:
 
   if t >= max_sequence_length:
     return (zero_output, state)
@@ -216,6 +216,10 @@ def _rnn_step(
     call_cell: lambda returning tuple of (new_output, new_state) where
       new_output is a `Tensor` matrix of shape [batch_size, output_size]
       new_state is a `Tensor` matrix of shape [batch_size, state_size]
+    skip_conditionals: Python bool, whether to skip using the conditional
+      calculations. This is useful for dynamic_rnn, where the input tensor
+      matches max_sequence_length, and using conditionals just slows
+      everything down.
 
   Returns:
     A tuple of (final_output, final_state) as given by the pseudocode above:
@@ -225,8 +229,15 @@ def _rnn_step(
   # Step 1: determine whether we need to call_cell or not
   empty_update = lambda: (zero_output, state)
   state_shape = state.get_shape()
-  output, new_state = control_flow_ops.cond(
-      time < max_sequence_length, call_cell, empty_update)
+
+  if skip_conditionals:
+    # Skip using conditionals: calculate the RNN step at all time
+    # steps. This is faster for dynamic_rnn, where the time steps
+    # should cap out at max_sequence_length anyway.
+    output, new_state = call_cell()
+  else:
+    output, new_state = control_flow_ops.cond(
+        time < max_sequence_length, call_cell, empty_update)
 
   # Step 2: determine whether we need to copy through state and/or outputs
   existing_output_state = lambda: (output, new_state)
@@ -239,8 +250,17 @@ def _rnn_step(
     return (math_ops.select(copy_cond, zero_output, output),
             math_ops.select(copy_cond, state, new_state))
 
-  (output, state) = control_flow_ops.cond(
-      time < min_sequence_length, existing_output_state, copy_through)
+  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
+  # but benefits from removing cond() and its gradient. We should
+  # profile with and without this switch here.
+  if skip_conditionals:
+    # Skip using conditionals: perform the selective copy at all time
+    # steps. This is usually faster.
+    (output, state) = copy_through()
+  else:
+    (output, state) = control_flow_ops.cond(
+        time < min_sequence_length, existing_output_state, copy_through)
+
   output.set_shape(zero_output.get_shape())
   state.set_shape(state_shape)
   return (output, state)
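
The copy_through branch kept above relies on elementwise selection rather than a cond(): rows whose sequences have already finished keep zero output and the previous state. A minimal numpy sketch of that select semantics (shapes, values, and names are illustrative, not the TF implementation):

import numpy as np

batch_size, output_size = 4, 3
time_step = 5
sequence_length = np.array([7, 5, 3, 2])  # per-example lengths

output = np.ones((batch_size, output_size))       # new_output from the cell
zero_output = np.zeros((batch_size, output_size))

# copy_cond marks rows whose sequence has ended by this time step.
copy_cond = time_step >= sequence_length          # [False, True, True, True]

# Equivalent of math_ops.select(copy_cond, zero_output, output):
final_output = np.where(copy_cond[:, None], zero_output, output)
# Rows 1-3 are zeroed out; row 0 keeps the freshly computed output.
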
@@ -549,8 +569,14 @@ def _dynamic_rnn_loop(
 
     if sequence_length is not None:
       (output, new_state) = _rnn_step(
-          time, sequence_length, min_sequence_length, max_sequence_length,
-          zero_output, state, call_cell)
+          time=time,
+          sequence_length=sequence_length,
+          min_sequence_length=min_sequence_length,
+          max_sequence_length=max_sequence_length,
+          zero_output=zero_output,
+          state=state,
+          call_cell=call_cell,
+          skip_conditionals=True)
     else:
       (output, new_state) = call_cell()
 
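
From the caller's side, the skip_conditionals=True fast path is exercised whenever dynamic_rnn is given a sequence_length tensor. A hedged usage sketch against the TF API of this era (module paths such as tf.nn.rnn_cell varied across releases, so treat them as assumptions):

import tensorflow as tf

batch_size, max_time, depth, num_units = 4, 10, 8, 16
inputs = tf.placeholder(tf.float32, [batch_size, max_time, depth])
lengths = tf.placeholder(tf.int64, [batch_size])  # per-example lengths

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
# With sequence_length set, _dynamic_rnn_loop invokes _rnn_step with
# skip_conditionals=True: the cell runs at every time step, and the
# select()-based copy_through zeroes outputs / holds state for rows
# that are past their length.
outputs, final_state = tf.nn.dynamic_rnn(
    cell, inputs, sequence_length=lengths, dtype=tf.float32)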