diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/loss.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/loss.py | 39 |
1 file changed, 26 insertions, 13 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py index cfe6ac5134..39a6d2f58b 100644 --- a/tensorflow/contrib/seq2seq/python/ops/loss.py +++ b/tensorflow/contrib/seq2seq/python/ops/loss.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Seq2seq loss operations for use in sequence models. """ @@ -28,22 +27,33 @@ from tensorflow.python.ops import nn_ops __all__ = ["sequence_loss"] -def sequence_loss(logits, targets, weights, - average_across_timesteps=True, average_across_batch=True, - softmax_loss_function=None, name=None): - """Weighted cross-entropy loss for a sequence of logits (per example). +def sequence_loss(logits, + targets, + weights, + average_across_timesteps=True, + average_across_batch=True, + softmax_loss_function=None, + name=None): + """Weighted cross-entropy loss for a sequence of logits. + + Depending on the values of `average_across_timesteps` and + `average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these + arguments reduce the cross-entropy at each target, which has shape + `[batch_size, sequence_length]`, over their respective dimensions. For + example, if `average_across_timesteps` is `True` and `average_across_batch` + is `False`, then the return Tensor will have shape `[batch_size]`. Args: - logits: A 3D Tensor of shape - [batch_size x sequence_length x num_decoder_symbols] and dtype float. + logits: A Tensor of shape + `[batch_size, sequence_length, num_decoder_symbols]` and dtype float. The logits correspond to the prediction across all classes at each timestep. - targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype + targets: A Tensor of shape `[batch_size, sequence_length]` and dtype int. The target represents the true class at each timestep. 
- weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype - float. Weights constitutes the weighting of each prediction in the - sequence. When using weights as masking set all valid timesteps to 1 and - all padded timesteps to 0. + weights: A Tensor of shape `[batch_size, sequence_length]` and dtype + float. `weights` constitutes the weighting of each prediction in the + sequence. When using `weights` as masking, set all valid timesteps to 1 + and all padded timesteps to 0, e.g. a mask returned by `tf.sequence_mask`. average_across_timesteps: If set, sum the cost across the sequence dimension and divide the cost by the total label weight across timesteps. average_across_batch: If set, sum the cost across the batch dimension and @@ -55,7 +65,10 @@ def sequence_loss(logits, targets, weights, name: Optional name for this operation, defaults to "sequence_loss". Returns: - A scalar float Tensor: The average log-perplexity per symbol (weighted). + A float Tensor of rank 0, 1, or 2 depending on the + `average_across_timesteps` and `average_across_batch` arguments. By default, + it has rank 0 (scalar) and is the weighted average cross-entropy + (log-perplexity) per symbol. Raises: ValueError: logits does not have 3 dimensions or targets does not have 2 |