path: root/tensorflow/contrib/recurrent
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-24 17:26:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-24 17:30:02 -0700
commit    9ab01c6732dae1143e22713375a9cc7758216787 (patch)
tree      47de4f82fe041ccb6a50f7384f5821f5c20cb264 /tensorflow/contrib/recurrent
parent    73083d29afe770870742a9d19555686886e76f6d (diff)
Update the functional RNN API to add a fast path when the cell function is a no-op for padded input.
PiperOrigin-RevId: 214360620
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r--  tensorflow/contrib/recurrent/python/ops/functional_rnn.py   96
-rw-r--r--  tensorflow/contrib/recurrent/python/ops/recurrent.py        37
2 files changed, 93 insertions, 40 deletions
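This change adds a `reverse` argument to functional_rnn and a `fast_reverse` argument to bidirectional_functional_rnn. A rough caller sketch follows; the import path, cell choice and tensor shapes are assumptions for illustration, and only the new flags come from this commit.

# Hypothetical usage sketch: exercising the flags added by this commit.
import tensorflow as tf
from tensorflow.contrib.recurrent.python.ops import functional_rnn as frnn

inputs = tf.random_normal([4, 7, 3])        # [batch, time, depth]
lengths = tf.fill([4], 7)                   # equal lengths, so the fast path is safe
cell_fw = tf.nn.rnn_cell.BasicLSTMCell(5)
cell_bw = tf.nn.rnn_cell.BasicLSTMCell(5)

# fast_reverse=True replaces tf.reverse_sequence with a plain tf.reverse on the
# time axis; this is only valid when all lengths in the batch match, or when
# the cell does not change its state on padded steps.
outputs, states = frnn.bidirectional_functional_rnn(
    cell_fw, cell_bw, inputs,
    sequence_length=lengths,
    dtype=tf.float32,
    fast_reverse=True)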
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index efaf63086f..3abf7bd6da 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -219,7 +219,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length):
def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
- total_time, inputs_lengths):
+ total_time, inputs_lengths, is_reversed):
"""Post-process output of recurrent.
This function takes the accumulated extended state and extracts the requested
@@ -228,6 +228,8 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
When `inputs_lengths` has been set, it extracts the output from the
accumulated state. It also sets outputs past the sequence lengths to zero.
+ When `is_reversed` is true, the output will be reversed in this function.
+
It also sets the static shape information.
Args:
@@ -238,11 +240,12 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
func_cell: The functional wrapper around the cell.
total_time: A scalar integer tensor.
inputs_lengths: An integer tensor with one entry per input.
+ is_reversed: A boolean to indicate if the sequence is reversed.
Returns:
A tuple with the outputs at each time, and the final state.
"""
- if inputs_lengths is None:
+ if inputs_lengths is None or is_reversed:
flat_final_state = func_cell.MaybeRemoveOutputFromState(
nest.flatten(extended_final_state))
tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
@@ -256,21 +259,28 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)
output_from_state = func_cell.GetOutputFromState(extended_acc_state)
+ if is_reversed:
+ output_from_state = array_ops.reverse(output_from_state, [0])
tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
tf_output.set_shape(
[func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
if inputs_lengths is not None:
# Need to set the outputs to zero.
tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
- # tf_output = array_ops.zeros([4, 3, 5])
_SetShapeFromTemplate(tf_state, func_cell.state_template)
return tf_output, tf_state
# pylint: disable=invalid-name
-def functional_rnn(cell, inputs, sequence_length=None,
- initial_state=None, dtype=None, time_major=False,
- scope=None, use_tpu=False):
+def functional_rnn(cell,
+ inputs,
+ sequence_length=None,
+ initial_state=None,
+ dtype=None,
+ time_major=False,
+ scope=None,
+ use_tpu=False,
+ reverse=False):
"""Same interface as `tf.nn.dynamic_rnn`."""
with variable_scope.variable_scope(scope or 'rnn'):
if not time_major:
@@ -285,33 +295,41 @@ def functional_rnn(cell, inputs, sequence_length=None,
max_length = math_ops.reduce_max(sequence_length)
else:
max_length = None
+ if reverse:
+ inputs = array_ops.reverse(inputs, [0])
extended_acc_state, extended_final_state = recurrent.Recurrent(
theta=func_cell.theta,
state0=func_cell.extended_initial_state,
inputs=inputs,
cell_fn=func_cell.cell_step,
max_input_length=max_length,
- use_tpu=use_tpu)
+ use_tpu=use_tpu,
+ aligned_end=reverse)
+
tf_output, tf_state = _PostProcessOutput(
- extended_acc_state, extended_final_state, func_cell,
- inputs_flat[0].shape[0], sequence_length)
+ extended_acc_state,
+ extended_final_state,
+ func_cell,
+ inputs_flat[0].shape[0],
+ sequence_length,
+ is_reversed=reverse)
if time_major:
tf_output = array_ops.transpose(tf_output, [1, 0, 2])
return tf_output, tf_state
-def bidirectional_functional_rnn(
- cell_fw,
- cell_bw,
- inputs,
- initial_state_fw=None,
- initial_state_bw=None,
- dtype=None,
- sequence_length=None,
- time_major=False,
- use_tpu=False,
- scope=None):
+def bidirectional_functional_rnn(cell_fw,
+ cell_bw,
+ inputs,
+ initial_state_fw=None,
+ initial_state_bw=None,
+ dtype=None,
+ sequence_length=None,
+ time_major=False,
+ use_tpu=False,
+ fast_reverse=False,
+ scope=None):
"""Creates a bidirectional recurrent neural network.
Performs fully dynamic unrolling of inputs in both directions. Built to be API
@@ -342,6 +360,10 @@ def bidirectional_functional_rnn(
use_tpu: Whether to enable TPU-compatible operation. If True, does not truly
reverse `inputs` in the backwards RNN. Once b/69305369 is fixed, we can
remove this flag.
+ fast_reverse: Whether to use fast tf.reverse to replace tf.reverse_sequence.
+ This is only possible when either all sequence lengths are the same inside
+ the batch, or when the cell function does not change the state on padded
+ input.
scope: An optional scope name for the dynamic RNN.
Returns:
@@ -390,17 +412,29 @@ def bidirectional_functional_rnn(
return array_ops.reverse(input_, axis=[seq_dim])
with variable_scope.variable_scope('bw') as bw_scope:
- inputs_reverse = _reverse(
- inputs, seq_lengths=sequence_length,
- seq_dim=time_dim, batch_dim=batch_dim)
- tmp, output_state_bw = functional_rnn(
- cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
- initial_state=initial_state_bw, dtype=dtype,
- time_major=time_major, scope=bw_scope, use_tpu=use_tpu)
-
- output_bw = _reverse(
- tmp, seq_lengths=sequence_length,
- seq_dim=time_dim, batch_dim=batch_dim)
+ if not fast_reverse:
+ inputs = _reverse(
+ inputs,
+ seq_lengths=sequence_length,
+ seq_dim=time_dim,
+ batch_dim=batch_dim)
+ output_bw, output_state_bw = functional_rnn(
+ cell=cell_bw,
+ inputs=inputs,
+ sequence_length=sequence_length,
+ initial_state=initial_state_bw,
+ dtype=dtype,
+ time_major=time_major,
+ scope=bw_scope,
+ use_tpu=use_tpu,
+ reverse=fast_reverse)
+
+ if not fast_reverse:
+ output_bw = _reverse(
+ output_bw,
+ seq_lengths=sequence_length,
+ seq_dim=time_dim,
+ batch_dim=batch_dim)
outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)
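The speedup comes from the reversal op itself: tf.reverse_sequence reverses each batch row only up to that row's own length, while tf.reverse flips the whole time axis, padding included, which is why the fast path needs equal lengths or a pad-noop cell. A standalone illustration with assumed toy values (batch-major for readability; inside functional_rnn the reverse is applied on the time-major axis 0):

# Illustration only (assumed values) of the two reversal styles that the
# fast_reverse flag chooses between.
import tensorflow as tf

x = tf.constant([[1, 2, 3, 0, 0],
                 [4, 5, 0, 0, 0]])   # batch-major, zero-padded
lens = tf.constant([3, 2])

# tf.reverse_sequence reverses each row only up to its own length:
#   [[3, 2, 1, 0, 0], [5, 4, 0, 0, 0]]
per_row = tf.reverse_sequence(x, lens, seq_axis=1, batch_axis=0)

# tf.reverse flips the entire time axis, so the padding moves to the front:
#   [[0, 0, 3, 2, 1], [0, 0, 0, 5, 4]]
whole_axis = tf.reverse(x, [1])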
diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py
index 4f289e0c85..f51de755d8 100644
--- a/tensorflow/contrib/recurrent/python/ops/recurrent.py
+++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py
@@ -274,8 +274,16 @@ def _ConvertNoneGradientToZeros(xs, dxs):
class _Recurrent(object):
"""A helper class to construct a recurrent neural net."""
- def __init__(self, cell_fn, cell_grad, theta, state0, inputs,
- max_input_length, extras, use_tpu):
+ def __init__(self,
+ cell_fn,
+ cell_grad,
+ theta,
+ state0,
+ inputs,
+ max_input_length,
+ extras,
+ use_tpu,
+ aligned_end=False):
"""RNN helper class.
Args:
@@ -294,6 +302,8 @@ class _Recurrent(object):
and shapes of this `extras`.
use_tpu: A boolean indicating whether the computation is meant to
run on a TPU.
+ aligned_end: A boolean indicating whether the sequence is aligned at
+ the end.
"""
self._theta = theta
self._state = state0
@@ -303,6 +313,7 @@ class _Recurrent(object):
self._cell_fn = cell_fn
self._cell_grad = cell_grad
self._extras = extras
+ self._aligned_end = aligned_end
# pylint: disable=unbalanced-tuple-unpacking
@@ -417,10 +428,11 @@ class _Recurrent(object):
acc_state = _EmptyAcc(slen_dim, state0)
acc_extras = _EmptyAcc(slen_dim, extras)
- dev_t = array_ops.constant(0, dtype=dev_t_type)
+ t = slen_dim - max_input_length if self._aligned_end else 0
+ dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
run = functional_ops.For(
- start=0,
- limit=max_input_length,
+ start=t,
+ limit=slen_dim if self._aligned_end else max_input_length,
delta=1,
inputs=[dev_t] + _Flatten(
[theta, state0, inputs, acc_state, acc_extras]),
@@ -551,13 +563,16 @@ class _Recurrent(object):
d_theta = _EmptyLike(theta)
d_inputs = _EmptyLike(inputs)
+ slen_dim = _SeqLenDim(inputs)
+
# Loop backwards. Note the loop's limit is open-ended, so goes through
# t=0.
- t = max_input_length - 1
+ t = slen_dim - 1 if self._aligned_end else max_input_length - 1
dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
+ limit = slen_dim - max_input_length - 1 if self._aligned_end else -1
run = functional_ops.For(
start=t,
- limit=-1,
+ limit=limit,
delta=-1,
inputs=[dev_t] + _Flatten([
theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
@@ -641,7 +656,8 @@ def Recurrent(theta,
cell_grad=None,
extras=None,
max_input_length=None,
- use_tpu=False):
+ use_tpu=False,
+ aligned_end=False):
"""Compute a recurrent neural net.
Roughly, Recurrent() computes the following:
@@ -684,6 +700,8 @@ def Recurrent(theta,
truncate the computation if the inputs have been allocated to a
larger size. A scalar tensor.
use_tpu: whether or not we are on TPU.
+ aligned_end: A boolean indicating whether the sequence is aligned at
+ the end.
Returns:
accumulate_state and the final state.
@@ -717,4 +735,5 @@ def Recurrent(theta,
inputs=inputs,
max_input_length=max_input_length,
extras=extras,
- use_tpu=use_tpu).Compute()
+ use_tpu=use_tpu,
+ aligned_end=aligned_end).Compute()