diff options
Diffstat (limited to 'tensorflow/python/ops/seq2seq.py')
-rw-r--r-- | tensorflow/python/ops/seq2seq.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py index f96e00acf5..7a4b547fac 100644 --- a/tensorflow/python/ops/seq2seq.py +++ b/tensorflow/python/ops/seq2seq.py @@ -249,8 +249,11 @@ def embedding_rnn_decoder(decoder_inputs, Returns: A tuple of the form (outputs, state), where: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x output_size] containing the generated outputs. + outputs: A list of the same length as decoder_inputs of 2D Tensors. The + output is of shape [batch_size x cell.output_size] when + output_projection is not None (and represents the dense representation + of predicted tokens). It is of shape [batch_size x num_decoder_symbols] + when output_projection is None. state: The state of each decoder cell in each time-step. This is a list with length len(decoder_inputs) -- one item for each time-step. It is a 2D Tensor of shape [batch_size x cell.state_size]. @@ -318,9 +321,11 @@ def embedding_rnn_seq2seq(encoder_inputs, Returns: A tuple of the form (outputs, state), where: - outputs: A list of the same length as decoder_inputs of 2D Tensors with - shape [batch_size x num_decoder_symbols] containing the generated - outputs. + outputs: A list of the same length as decoder_inputs of 2D Tensors. The + output is of shape [batch_size x cell.output_size] when + output_projection is not None (and represents the dense representation + of predicted tokens). It is of shape [batch_size x num_decoder_symbols] + when output_projection is None. state: The state of each decoder cell in each time-step. This is a list with length len(decoder_inputs) -- one item for each time-step. It is a 2D Tensor of shape [batch_size x cell.state_size]. @@ -1082,7 +1087,9 @@ def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, Returns: A tuple of the form (outputs, losses), where: outputs: The outputs for each bucket. Its j'th element consists of a list - of 2D Tensors of shape [batch_size x num_decoder_symbols] (jth outputs). + of 2D Tensors. The shape of output tensors can be either + [batch_size x output_size] or [batch_size x num_decoder_symbols] + depending on the seq2seq model used. losses: List of scalar Tensors, representing losses for each bucket, or, if per_example_loss is set, a list of 1D batch-sized float Tensors. |