Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/loss.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/loss.py | 78
1 file changed, 74 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py
index b8a33b3f6f..bb87111266 100644
--- a/tensorflow/contrib/seq2seq/python/ops/loss.py
+++ b/tensorflow/contrib/seq2seq/python/ops/loss.py
@@ -13,11 +13,81 @@
# limitations under the License.
# ==============================================================================
-"""Seq2seq loss operations for use in neural networks.
+"""Seq2seq loss operations for use in sequence models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+__all__ = ["sequence_loss"]
-__all__ = ["seq2seq_loss"]
+def sequence_loss(logits, targets, weights,
+                  average_across_timesteps=True, average_across_batch=True,
+                  softmax_loss_function=None, name=None):
+  """Weighted cross-entropy loss for a sequence of logits (per example).
+  Args:
+    logits: A 3D Tensor of shape
+      [batch_size x sequence_length x num_decoder_symbols] and dtype float.
+      The logits correspond to the prediction across all classes at each
+      timestep.
+    targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
+      int. The target represents the true class at each timestep.
+    weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
+      float. Weights constitute the weighting of each prediction in the
+      sequence. When using weights as masking, set all valid timesteps to 1
+      and all padded timesteps to 0.
+    average_across_timesteps: If set, sum the cost across the sequence
+      dimension and divide the cost by the total label weight across
+      timesteps.
+    average_across_batch: If set, sum the cost across the batch dimension and
+      divide the returned cost by the batch size.
+    softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
+      to be used instead of the standard softmax (the default if this is None).
+    name: Optional name for this operation, defaults to "sequence_loss".
-def seq2seq_loss(*args, **kwargs):
- pass
+  Returns:
+    A float Tensor: the weighted average log-perplexity per symbol (a scalar
+    when both averaging flags are set).
+
+  Raises:
+    ValueError: logits is not 3D, or targets or weights is not 2D.
+  """
+  if len(logits.get_shape()) != 3:
+    raise ValueError("Logits must be a [batch_size x sequence_length x "
+                     "num_decoder_symbols] tensor")
+  if len(targets.get_shape()) != 2:
+    raise ValueError("Targets must be a [batch_size x sequence_length] "
+                     "tensor")
+  if len(weights.get_shape()) != 2:
+    raise ValueError("Weights must be a [batch_size x sequence_length] "
+                     "tensor")
+  with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
+    num_classes = array_ops.shape(logits)[2]
+    logits_flat = array_ops.reshape(logits, [-1, num_classes])
+    targets = array_ops.reshape(targets, [-1])
+    if softmax_loss_function is None:
+      crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
+          labels=targets, logits=logits_flat)
+    else:
+      crossent = softmax_loss_function(logits_flat, targets)
+    crossent = crossent * array_ops.reshape(weights, [-1])
+    if average_across_timesteps and average_across_batch:
+      crossent = math_ops.reduce_sum(crossent)
+      total_size = math_ops.reduce_sum(weights)
+      total_size += 1e-12  # to avoid division by 0 for all-0 weights
+      crossent /= total_size
+    else:
+      batch_size = array_ops.shape(logits)[0]
+      sequence_length = array_ops.shape(logits)[1]
+      crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
+    if average_across_timesteps and not average_across_batch:
+      crossent = math_ops.reduce_sum(crossent, axis=[1])
+      total_size = math_ops.reduce_sum(weights, axis=[1])
+      total_size += 1e-12  # to avoid division by 0 for all-0 weights
+      crossent /= total_size
+    if not average_across_timesteps and average_across_batch:
+      crossent = math_ops.reduce_sum(crossent, axis=[0])
+      total_size = math_ops.reduce_sum(weights, axis=[0])
+      total_size += 1e-12  # to avoid division by 0 for all-0 weights
+      crossent /= total_size
+    return crossent
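
For context, a minimal usage sketch of the new op. This assumes a TF 1.x
build with the patch applied, and that the op is re-exported as
tf.contrib.seq2seq.sequence_loss (the shapes, the 0/1 padding mask, and the
averaging flags follow the docstring above); it is an illustration, not part
of the commit.

import numpy as np
import tensorflow as tf

batch_size, sequence_length, num_decoder_symbols = 2, 5, 7

# Per-timestep class scores, true classes, and a 0/1 padding mask.
logits = tf.placeholder(
    tf.float32, [batch_size, sequence_length, num_decoder_symbols])
targets = tf.placeholder(tf.int32, [batch_size, sequence_length])
weights = tf.placeholder(tf.float32, [batch_size, sequence_length])

# Defaults average over both timesteps and batch, giving a scalar loss.
loss = tf.contrib.seq2seq.sequence_loss(logits, targets, weights)
# Keeping the batch axis yields a per-example loss of shape [batch_size].
loss_per_example = tf.contrib.seq2seq.sequence_loss(
    logits, targets, weights, average_across_batch=False)

with tf.Session() as sess:
  feed = {
      logits: np.random.randn(batch_size, sequence_length,
                              num_decoder_symbols),
      targets: np.random.randint(num_decoder_symbols,
                                 size=(batch_size, sequence_length)),
      # The second sequence is padded after timestep 3, so its last two
      # positions contribute nothing to the loss.
      weights: np.array([[1., 1., 1., 1., 1.],
                         [1., 1., 1., 0., 0.]], dtype=np.float32),
  }
  print(sess.run([loss, loss_per_example], feed))

A custom softmax_loss_function can be swapped in for the standard sparse
softmax; per the function body above it is called as
softmax_loss_function(flattened_logits, flattened_targets) and must return a
per-timestep loss batch.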