aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/rnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/rnn.py')
-rw-r--r--tensorflow/python/ops/rnn.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index ad916b6b5f..611f5fa314 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -312,9 +312,11 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
scope: VariableScope for the created subgraph; defaults to "BiRNN"
Returns:
- A set of output `Tensors` where:
+ A tuple (outputs, output_state_fw, output_state_bw) where:
outputs is a length T list of outputs (one for each input), which
are depth-concatenated forward and backward outputs
+ output_state_fw is the final state of the forward rnn
+ output_state_bw is the final state of the backward rnn
Raises:
TypeError: If "cell_fw" or "cell_bw" is not an instance of RNNCell.
@@ -333,19 +335,19 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
name = scope or "BiRNN"
# Forward direction
with vs.variable_scope(name + "_FW") as fw_scope:
- output_fw, _ = rnn(cell_fw, inputs, initial_state_fw, dtype,
+ output_fw, output_state_fw = rnn(cell_fw, inputs, initial_state_fw, dtype,
sequence_length, scope=fw_scope)
# Backward direction
with vs.variable_scope(name + "_BW") as bw_scope:
- tmp, _ = rnn(cell_bw, _reverse_seq(inputs, sequence_length),
+ tmp, output_state_bw = rnn(cell_bw, _reverse_seq(inputs, sequence_length),
initial_state_bw, dtype, sequence_length, scope=bw_scope)
output_bw = _reverse_seq(tmp, sequence_length)
# Concat each of the forward/backward outputs
outputs = [array_ops.concat(1, [fw, bw])
for fw, bw in zip(output_fw, output_bw)]
- return outputs
+ return (outputs, output_state_fw, output_state_bw)
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,