diff options
Diffstat (limited to 'tensorflow/python/ops/rnn.py')
-rw-r--r-- | tensorflow/python/ops/rnn.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index ad916b6b5f..611f5fa314 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -312,9 +312,11 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs, scope: VariableScope for the created subgraph; defaults to "BiRNN" Returns: - A set of output `Tensors` where: + A tuple (outputs, output_state_fw, output_state_bw) where: outputs is a length T list of outputs (one for each input), which are depth-concatenated forward and backward outputs + output_state_fw is the final state of the forward rnn + output_state_bw is the final state of the backward rnn Raises: TypeError: If "cell_fw" or "cell_bw" is not an instance of RNNCell. @@ -333,19 +335,19 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs, name = scope or "BiRNN" # Forward direction with vs.variable_scope(name + "_FW") as fw_scope: - output_fw, _ = rnn(cell_fw, inputs, initial_state_fw, dtype, + output_fw, output_state_fw = rnn(cell_fw, inputs, initial_state_fw, dtype, sequence_length, scope=fw_scope) # Backward direction with vs.variable_scope(name + "_BW") as bw_scope: - tmp, _ = rnn(cell_bw, _reverse_seq(inputs, sequence_length), + tmp, output_state_bw = rnn(cell_bw, _reverse_seq(inputs, sequence_length), initial_state_bw, dtype, sequence_length, scope=bw_scope) output_bw = _reverse_seq(tmp, sequence_length) # Concat each of the forward/backward outputs outputs = [array_ops.concat(1, [fw, bw]) for fw, bw in zip(output_fw, output_bw)] - return outputs + return (outputs, output_state_fw, output_state_bw) def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, |