about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/seq2seq/python/ops/loss.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/loss.py')
-rw-r--r-- tensorflow/contrib/seq2seq/python/ops/loss.py | 39
1 file changed, 26 insertions(+), 13 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py
index cfe6ac5134..39a6d2f58b 100644
--- a/tensorflow/contrib/seq2seq/python/ops/loss.py
+++ b/tensorflow/contrib/seq2seq/python/ops/loss.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Seq2seq loss operations for use in sequence models.
"""
@@ -28,22 +27,33 @@ from tensorflow.python.ops import nn_ops
__all__ = ["sequence_loss"]
-def sequence_loss(logits, targets, weights,
- average_across_timesteps=True, average_across_batch=True,
- softmax_loss_function=None, name=None):
- """Weighted cross-entropy loss for a sequence of logits (per example).
+def sequence_loss(logits,
+ targets,
+ weights,
+ average_across_timesteps=True,
+ average_across_batch=True,
+ softmax_loss_function=None,
+ name=None):
+ """Weighted cross-entropy loss for a sequence of logits.
+
+ Depending on the values of `average_across_timesteps` and
+ `average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these
+ arguments reduce the cross-entropy at each target, which has shape
+ `[batch_size, sequence_length]`, over their respective dimensions. For
+ example, if `average_across_timesteps` is `True` and `average_across_batch`
+ is `False`, then the return Tensor will have shape `[batch_size]`.
Args:
- logits: A 3D Tensor of shape
- [batch_size x sequence_length x num_decoder_symbols] and dtype float.
+ logits: A Tensor of shape
+ `[batch_size, sequence_length, num_decoder_symbols]` and dtype float.
The logits correspond to the prediction across all classes at each
timestep.
- targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
+ targets: A Tensor of shape `[batch_size, sequence_length]` and dtype
int. The target represents the true class at each timestep.
- weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
- float. Weights constitutes the weighting of each prediction in the
- sequence. When using weights as masking set all valid timesteps to 1 and
- all padded timesteps to 0.
+ weights: A Tensor of shape `[batch_size, sequence_length]` and dtype
+ float. `weights` constitutes the weighting of each prediction in the
+ sequence. When using `weights` as masking, set all valid timesteps to 1
+ and all padded timesteps to 0, e.g. a mask returned by `tf.sequence_mask`.
average_across_timesteps: If set, sum the cost across the sequence
dimension and divide the cost by the total label weight across timesteps.
average_across_batch: If set, sum the cost across the batch dimension and
@@ -55,7 +65,10 @@ def sequence_loss(logits, targets, weights,
name: Optional name for this operation, defaults to "sequence_loss".
Returns:
- A scalar float Tensor: The average log-perplexity per symbol (weighted).
+ A float Tensor of rank 0, 1, or 2 depending on the
+ `average_across_timesteps` and `average_across_batch` arguments. By default,
+ it has rank 0 (scalar) and is the weighted average cross-entropy
+ (log-perplexity) per symbol.
Raises:
ValueError: logits does not have 3 dimensions or targets does not have 2