aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/seq2seq.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/seq2seq.py')
-rw-r--r--tensorflow/python/ops/seq2seq.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py
index f96e00acf5..7a4b547fac 100644
--- a/tensorflow/python/ops/seq2seq.py
+++ b/tensorflow/python/ops/seq2seq.py
@@ -249,8 +249,11 @@ def embedding_rnn_decoder(decoder_inputs,
Returns:
A tuple of the form (outputs, state), where:
- outputs: A list of the same length as decoder_inputs of 2D Tensors with
- shape [batch_size x output_size] containing the generated outputs.
+ outputs: A list of the same length as decoder_inputs of 2D Tensors. The
+ output is of shape [batch_size x cell.output_size] when
+ output_projection is not None (and represents the dense representation
+ of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
+ when output_projection is None.
state: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
It is a 2D Tensor of shape [batch_size x cell.state_size].
@@ -318,9 +321,11 @@ def embedding_rnn_seq2seq(encoder_inputs,
Returns:
A tuple of the form (outputs, state), where:
- outputs: A list of the same length as decoder_inputs of 2D Tensors with
- shape [batch_size x num_decoder_symbols] containing the generated
- outputs.
+ outputs: A list of the same length as decoder_inputs of 2D Tensors. The
+ output is of shape [batch_size x cell.output_size] when
+ output_projection is not None (and represents the dense representation
+ of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
+ when output_projection is None.
state: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
It is a 2D Tensor of shape [batch_size x cell.state_size].
@@ -1082,7 +1087,9 @@ def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights,
Returns:
A tuple of the form (outputs, losses), where:
outputs: The outputs for each bucket. Its j'th element consists of a list
- of 2D Tensors of shape [batch_size x num_decoder_symbols] (jth outputs).
+ of 2D Tensors. The shape of output tensors can be either
+ [batch_size x output_size] or [batch_size x num_decoder_symbols]
+ depending on the seq2seq model used.
losses: List of scalar Tensors, representing losses for each bucket, or,
if per_example_loss is set, a list of 1D batch-sized float Tensors.