author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-24 17:26:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-24 17:30:02 -0700
commit     9ab01c6732dae1143e22713375a9cc7758216787
tree       47de4f82fe041ccb6a50f7384f5821f5c20cb264 /tensorflow/contrib/recurrent
parent     73083d29afe770870742a9d19555686886e76f6d
Update the functional RNN API to add a fast path for the case where the cell function is a no-op on padded input.
PiperOrigin-RevId: 214360620
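The fast path rests on a single observation: when a cell leaves its state untouched on padded steps, the whole time axis can be flipped with one tf.reverse instead of the per-row tf.reverse_sequence, because the padding that lands at the front of each sequence is simply carried through. A minimal sketch of that equivalence in plain Python, with a toy accumulator standing in for an RNN cell (the cell and sequence below are illustrative, not taken from the change):

    def scan(seq, cell, state=0.0):
      # Run a left-to-right recurrence over a 1-D sequence.
      for x in seq:
        state = cell(state, x)
      return state

    # Toy cell that is a no-op on padded (zero) steps -- the precondition for
    # the fast path added here.
    cell = lambda s, x: s if x == 0 else 2 * s + x

    seq = [1., 2., 3., 0., 0.]  # length 3, zero-padded to T=5

    # Slow path: reverse only the valid prefix (what tf.reverse_sequence does).
    slow = scan(seq[:3][::-1], cell)
    # Fast path: flip the whole axis (what tf.reverse does); the padding lands
    # in front and the no-op cell carries the initial state through it.
    fast = scan(seq[::-1], cell)
    assert slow == fast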
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r--  tensorflow/contrib/recurrent/python/ops/functional_rnn.py  | 96
-rw-r--r--  tensorflow/contrib/recurrent/python/ops/recurrent.py       | 37
2 files changed, 93 insertions(+), 40 deletions(-)
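For reference, a hedged usage sketch of the new fast_reverse flag (TF 1.x; the import path simply mirrors the file location in this diff, and the shapes are made up):

    import tensorflow as tf
    from tensorflow.contrib.recurrent.python.ops.functional_rnn import (
        bidirectional_functional_rnn)

    cell_fw = tf.nn.rnn_cell.LSTMCell(32)
    cell_bw = tf.nn.rnn_cell.LSTMCell(32)
    inputs = tf.placeholder(tf.float32, [8, 10, 16])  # [batch, time, depth]
    lengths = tf.placeholder(tf.int32, [8])

    # fast_reverse=True swaps tf.reverse_sequence for tf.reverse in the
    # backward direction; per the docstring below, this is only valid when all
    # lengths in the batch are equal or the cell is a no-op on padded input.
    outputs, output_states = bidirectional_functional_rnn(
        cell_fw, cell_bw, inputs,
        dtype=tf.float32,
        sequence_length=lengths,
        fast_reverse=True)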
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index efaf63086f..3abf7bd6da 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -219,7 +219,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length):
 
 
 def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
-                       total_time, inputs_lengths):
+                       total_time, inputs_lengths, is_reversed):
   """Post-process output of recurrent.
 
   This function takes the accumulated extended state and extracts the requested
@@ -228,6 +228,8 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
   When `inputs_lengths` has been set, it extracts the output from the
   accumulated state. It also sets outputs past.
 
+  When `is_reversed` is true, the output will be reversed in this function.
+
   It also sets the static shape information.
 
   Args:
@@ -238,11 +240,12 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
     func_cell: The functional wrapper around the cell.
     total_time: A scalar integer tensor.
     inputs_lengths: An integer tensor with one entry per input.
+    is_reversed: A boolean to indicate if the sequence is reversed.
 
   Returns:
     A tuple with the outputs at each time, and the final state.
   """
-  if inputs_lengths is None:
+  if inputs_lengths is None or is_reversed:
     flat_final_state = func_cell.MaybeRemoveOutputFromState(
         nest.flatten(extended_final_state))
     tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
@@ -256,21 +259,28 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
     tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)
 
   output_from_state = func_cell.GetOutputFromState(extended_acc_state)
+  if is_reversed:
+    output_from_state = array_ops.reverse(output_from_state, [0])
   tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
   tf_output.set_shape(
       [func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
   if inputs_lengths is not None:
     # Need set the outputs to zero.
     tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
-    # tf_output = array_ops.zeros([4, 3, 5])
   _SetShapeFromTemplate(tf_state, func_cell.state_template)
   return tf_output, tf_state
 
 
 # pylint: disable=invalid-name
-def functional_rnn(cell, inputs, sequence_length=None,
-                   initial_state=None, dtype=None, time_major=False,
-                   scope=None, use_tpu=False):
+def functional_rnn(cell,
+                   inputs,
+                   sequence_length=None,
+                   initial_state=None,
+                   dtype=None,
+                   time_major=False,
+                   scope=None,
+                   use_tpu=False,
+                   reverse=False):
   """Same interface as `tf.nn.dynamic_rnn`."""
   with variable_scope.variable_scope(scope or 'rnn'):
     if not time_major:
@@ -285,33 +295,41 @@ def functional_rnn(cell, inputs, sequence_length=None,
       max_length = math_ops.reduce_max(sequence_length)
     else:
       max_length = None
+    if reverse:
+      inputs = array_ops.reverse(inputs, [0])
     extended_acc_state, extended_final_state = recurrent.Recurrent(
         theta=func_cell.theta,
         state0=func_cell.extended_initial_state,
         inputs=inputs,
         cell_fn=func_cell.cell_step,
         max_input_length=max_length,
-        use_tpu=use_tpu)
+        use_tpu=use_tpu,
+        aligned_end=reverse)
+
     tf_output, tf_state = _PostProcessOutput(
-        extended_acc_state, extended_final_state, func_cell,
-        inputs_flat[0].shape[0], sequence_length)
+        extended_acc_state,
+        extended_final_state,
+        func_cell,
+        inputs_flat[0].shape[0],
+        sequence_length,
+        is_reversed=reverse)
 
     if time_major:
       tf_output = array_ops.transpose(tf_output, [1, 0, 2])
     return tf_output, tf_state
 
 
-def bidirectional_functional_rnn(
-    cell_fw,
-    cell_bw,
-    inputs,
-    initial_state_fw=None,
-    initial_state_bw=None,
-    dtype=None,
-    sequence_length=None,
-    time_major=False,
-    use_tpu=False,
-    scope=None):
+def bidirectional_functional_rnn(cell_fw,
+                                 cell_bw,
+                                 inputs,
+                                 initial_state_fw=None,
+                                 initial_state_bw=None,
+                                 dtype=None,
+                                 sequence_length=None,
+                                 time_major=False,
+                                 use_tpu=False,
+                                 fast_reverse=False,
+                                 scope=None):
   """Creates a bidirectional recurrent neural network.
 
   Performs fully dynamic unrolling of inputs in both directions. Built to be API
@@ -342,6 +360,10 @@ def bidirectional_functional_rnn(
     use_tpu: Whether to enable TPU-compatible operation. If True, does not truly
      reverse `inputs` in the backwards RNN. Once b/69305369 is fixed, we can
      remove this flag.
+    fast_reverse: Whether to use fast tf.reverse to replace tf.reverse_sequence.
+      This is only possible when either all sequence lengths are the same inside
+      the batch, or when the cell function does not change the state on padded
+      input.
     scope: An optional scope name for the dynamic RNN.
 
   Returns:
@@ -390,17 +412,29 @@ def bidirectional_functional_rnn(
       return array_ops.reverse(input_, axis=[seq_dim])
 
     with variable_scope.variable_scope('bw') as bw_scope:
-      inputs_reverse = _reverse(
-          inputs, seq_lengths=sequence_length,
-          seq_dim=time_dim, batch_dim=batch_dim)
-      tmp, output_state_bw = functional_rnn(
-          cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
-          initial_state=initial_state_bw, dtype=dtype,
-          time_major=time_major, scope=bw_scope, use_tpu=use_tpu)
-
-    output_bw = _reverse(
-        tmp, seq_lengths=sequence_length,
-        seq_dim=time_dim, batch_dim=batch_dim)
+      if not fast_reverse:
+        inputs = _reverse(
+            inputs,
+            seq_lengths=sequence_length,
+            seq_dim=time_dim,
+            batch_dim=batch_dim)
+      output_bw, output_state_bw = functional_rnn(
+          cell=cell_bw,
+          inputs=inputs,
+          sequence_length=sequence_length,
+          initial_state=initial_state_bw,
+          dtype=dtype,
+          time_major=time_major,
+          scope=bw_scope,
+          use_tpu=use_tpu,
+          reverse=fast_reverse)
+
+    if not fast_reverse:
+      output_bw = _reverse(
+          output_bw,
+          seq_lengths=sequence_length,
+          seq_dim=time_dim,
+          batch_dim=batch_dim)
 
   outputs = (output_fw, output_bw)
   output_states = (output_state_fw, output_state_bw)
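A note on the `inputs_lengths is None or is_reversed` branch above: flipping the whole time axis leaves every row's valid data end-aligned, so the recurrence finishes on each row's last real step and the final loop state is already correct for every row, with no `_PickFinalStateFromHistory` gather needed. A small numpy illustration of the layout (values are made up):

    import numpy as np

    # Time-major batch [T=4, B=2]; row lengths 3 and 2, zero-padded at the end.
    x = np.array([[1., 7.],
                  [2., 8.],
                  [3., 0.],
                  [0., 0.]])

    # array_ops.reverse(inputs, [0]) flips the whole time axis: each sequence
    # comes out reversed *and* end-aligned, with its padding moved to the front.
    print(np.flip(x, 0))
    # [[0. 0.]
    #  [3. 0.]
    #  [2. 8.]
    #  [1. 7.]]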
diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py
index 4f289e0c85..f51de755d8 100644
--- a/tensorflow/contrib/recurrent/python/ops/recurrent.py
+++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py
@@ -274,8 +274,16 @@ def _ConvertNoneGradientToZeros(xs, dxs):
 class _Recurrent(object):
   """A helper class to construct a recurrent neural net."""
 
-  def __init__(self, cell_fn, cell_grad, theta, state0, inputs,
-               max_input_length, extras, use_tpu):
+  def __init__(self,
+               cell_fn,
+               cell_grad,
+               theta,
+               state0,
+               inputs,
+               max_input_length,
+               extras,
+               use_tpu,
+               aligned_end=False):
     """RNN helper class.
 
     Args:
@@ -294,6 +302,8 @@ class _Recurrent(object):
         and shapes of this `extras`.
       use_tpu: A boolean indicating whether the computation is mean to
         run on a TPU.
+      aligned_end: A boolean indicating whether the sequence is aligned at
+        the end.
     """
     self._theta = theta
     self._state = state0
@@ -303,6 +313,7 @@ class _Recurrent(object):
     self._cell_fn = cell_fn
     self._cell_grad = cell_grad
     self._extras = extras
+    self._aligned_end = aligned_end
 
     # pylint: disable=unbalanced-tuple-unpacking
 
@@ -417,10 +428,11 @@ class _Recurrent(object):
       acc_state = _EmptyAcc(slen_dim, state0)
       acc_extras = _EmptyAcc(slen_dim, extras)
 
-      dev_t = array_ops.constant(0, dtype=dev_t_type)
+      t = slen_dim - max_input_length if self._aligned_end else 0
+      dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
       run = functional_ops.For(
-          start=0,
-          limit=max_input_length,
+          start=t,
+          limit=slen_dim if self._aligned_end else max_input_length,
           delta=1,
           inputs=[dev_t] + _Flatten(
              [theta, state0, inputs, acc_state, acc_extras]),
@@ -551,13 +563,16 @@ class _Recurrent(object):
       d_theta = _EmptyLike(theta)
       d_inputs = _EmptyLike(inputs)
 
+      slen_dim = _SeqLenDim(inputs)
+
       # Loop backwards. Note the loop's limit is open-ended, so goes through
       # t=0.
-      t = max_input_length - 1
+      t = slen_dim - 1 if self._aligned_end else max_input_length - 1
       dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
+      limit = slen_dim - max_input_length - 1 if self._aligned_end else -1
       run = functional_ops.For(
           start=t,
-          limit=-1,
+          limit=limit,
           delta=-1,
           inputs=[dev_t] + _Flatten([
               theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
@@ -641,7 +656,8 @@ def Recurrent(theta,
               cell_grad=None,
               extras=None,
               max_input_length=None,
-              use_tpu=False):
+              use_tpu=False,
+              aligned_end=False):
   """Compute a recurrent neural net.
 
   Roughly, Recurrent() computes the following:
@@ -684,6 +700,8 @@ def Recurrent(theta,
       truncate the computation if the inputs have been allocated to a larger
       size. A scalar tensor.
     use_tpu: whether or not we are on TPU.
+    aligned_end: A boolean indicating whether the sequence is aligned at
+      the end.
 
   Returns:
     accumulate_state and the final state.
@@ -717,4 +735,5 @@ def Recurrent(theta,
       inputs=inputs,
       max_input_length=max_input_length,
       extras=extras,
-      use_tpu=use_tpu).Compute()
+      use_tpu=use_tpu,
+      aligned_end=aligned_end).Compute()
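A worked check of the aligned_end loop bounds from the hunks above, with made-up sizes (slen_dim is the padded buffer length, max_input_length the longest real sequence in the batch):

    slen_dim, max_input_length = 10, 6

    # Forward pass (delta=1): start at slen_dim - max_input_length and stop
    # before slen_dim, visiting only the trailing, end-aligned window.
    fwd = list(range(slen_dim - max_input_length, slen_dim))  # [4, ..., 9]

    # Backward pass (delta=-1): the limit is exclusive ("open-ended"), so the
    # loop runs from slen_dim - 1 down through slen_dim - max_input_length.
    bwd = list(range(slen_dim - 1, slen_dim - max_input_length - 1, -1))

    assert fwd == bwd[::-1]  # both passes cover the same aligned window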