Diffstat (limited to 'tensorflow/contrib/seq2seq')
 tensorflow/contrib/seq2seq/BUILD                            | 12 ++++
 tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py | 50 ++++++++++-
 tensorflow/contrib/seq2seq/python/ops/loss.py               | 78 +++++++++++-
 3 files changed, 133 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index 9566d03211..3c314e2f28 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -41,6 +41,18 @@ cuda_py_test(
)
cuda_py_test(
+    name = "loss_test",
+    size = "medium",
+    srcs = ["python/kernel_tests/loss_test.py"],
+    additional_deps = [
+        ":seq2seq_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_test(
    name = "seq2seq_test",
    size = "medium",
    srcs = ["python/kernel_tests/seq2seq_test.py"],
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
index f99de76f17..95560fb254 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
@@ -20,14 +20,58 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
+import numpy as np
import tensorflow as tf
-
class LossTest(tf.test.TestCase):
-  def testLoss(self):
-    pass
+  def testSequenceLoss(self):
+    with self.test_session() as sess:
+      with tf.variable_scope("root",
+                             initializer=tf.constant_initializer(0.5)):
+        batch_size = 2
+        sequence_length = 3
+        number_of_classes = 5
+        logits = [tf.constant(i + 0.5, shape=[batch_size, number_of_classes])
+                  for i in range(sequence_length)]
+        logits = tf.stack(logits, axis=1)
+        targets = [tf.constant(i, tf.int32, shape=[batch_size]) for i in
+                   range(sequence_length)]
+        targets = tf.stack(targets, axis=1)
+        weights = [tf.constant(1.0, shape=[batch_size]) for _ in
+                   range(sequence_length)]
+        weights = tf.stack(weights, axis=1)
+
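+        # With identical logits for every class, the softmax is uniform over
+        # the 5 classes, so the expected cross-entropy is ln(5) ~= 1.60944
+        # at every timestep of every example.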
+        average_loss_per_example = tf.contrib.seq2seq.sequence_loss(
+            logits, targets, weights,
+            average_across_timesteps=True,
+            average_across_batch=True)
+        res = sess.run(average_loss_per_example)
+        self.assertAllClose(1.60944, res)
+
+        average_loss_per_sequence = tf.contrib.seq2seq.sequence_loss(
+            logits, targets, weights,
+            average_across_timesteps=False,
+            average_across_batch=True)
+        res = sess.run(average_loss_per_sequence)
+        compare_per_sequence = np.ones((sequence_length)) * 1.60944
+        self.assertAllClose(compare_per_sequence, res)
+
+        average_loss_per_batch = tf.contrib.seq2seq.sequence_loss(
+            logits, targets, weights,
+            average_across_timesteps=True,
+            average_across_batch=False)
+        res = sess.run(average_loss_per_batch)
+        compare_per_batch = np.ones((batch_size)) * 1.60944
+        self.assertAllClose(compare_per_batch, res)
+
+        total_loss = tf.contrib.seq2seq.sequence_loss(
+            logits, targets, weights,
+            average_across_timesteps=False,
+            average_across_batch=False)
+        res = sess.run(total_loss)
+        compare_total = np.ones((batch_size, sequence_length)) * 1.60944
+        self.assertAllClose(compare_total, res)
if __name__ == '__main__':
  tf.test.main()
diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py
index b8a33b3f6f..bb87111266 100644
--- a/tensorflow/contrib/seq2seq/python/ops/loss.py
+++ b/tensorflow/contrib/seq2seq/python/ops/loss.py
@@ -13,18 +13,88 @@
# limitations under the License.
# ==============================================================================
-"""Seq2seq loss operations for use in neural networks.
+"""Seq2seq loss operations for use in sequence models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+__all__ = ["sequence_loss"]
-__all__ = ["seq2seq_loss"]
+def sequence_loss(logits, targets, weights,
+                  average_across_timesteps=True, average_across_batch=True,
+                  softmax_loss_function=None, name=None):
+  """Weighted cross-entropy loss for a sequence of logits (per example).
+
+  Args:
+    logits: A 3D Tensor of shape
+      [batch_size x sequence_length x num_decoder_symbols] and dtype float.
+      The logits correspond to the prediction across all classes at each
+      timestep.
+    targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
+      int. The target represents the true class at each timestep.
+    weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
+      float. `weights` constitutes the weighting of each prediction in the
+      sequence. When using `weights` as a mask, set all valid timesteps to 1
+      and all padded timesteps to 0.
+    average_across_timesteps: If set, sum the cost across the sequence
+      dimension and divide the cost by the total label weight across
+      timesteps.
+    average_across_batch: If set, sum the cost across the batch dimension and
+      divide the returned cost by the batch size.
+    softmax_loss_function: Function (logits-batch, labels-batch) -> loss-batch
+      to be used instead of the standard softmax cross-entropy loss (the
+      default if this is None).
+    name: Optional name for this operation, defaults to "sequence_loss".
-def seq2seq_loss(*args, **kwargs):
-  pass
+
+  Returns:
+    A float Tensor of weighted cross-entropy: a scalar if both
+    `average_across_timesteps` and `average_across_batch` are set, otherwise
+    a Tensor of rank 1 or 2.
+
+  Raises:
+    ValueError: if `logits` does not have 3 dimensions, or `targets` or
+      `weights` does not have 2 dimensions.
+  """
+  if len(logits.get_shape()) != 3:
+    raise ValueError("Logits must be a "
+                     "[batch_size x sequence_length x num_decoder_symbols] "
+                     "tensor")
+  if len(targets.get_shape()) != 2:
+    raise ValueError("Targets must be a [batch_size x sequence_length] "
+                     "tensor")
+  if len(weights.get_shape()) != 2:
+    raise ValueError("Weights must be a [batch_size x sequence_length] "
+                     "tensor")
+  with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
+    num_classes = array_ops.shape(logits)[2]
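+    # Flatten logits to [batch_size * sequence_length, num_classes] so that
+    # every timestep is scored as an independent classification example.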
+    logits_flat = array_ops.reshape(logits, [-1, num_classes])
+    targets = array_ops.reshape(targets, [-1])
+    if softmax_loss_function is None:
+      crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
+          labels=targets, logits=logits_flat)
+    else:
+      crossent = softmax_loss_function(logits_flat, targets)
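+    # Apply the flattened weights; padded timesteps with weight 0 contribute
+    # nothing to the loss.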
+    crossent = crossent * array_ops.reshape(weights, [-1])
+    if average_across_timesteps and average_across_batch:
+      crossent = math_ops.reduce_sum(crossent)
+      total_size = math_ops.reduce_sum(weights)
+      total_size += 1e-12  # to avoid division by 0 for all-0 weights
+      crossent /= total_size
+    else:
+      batch_size = array_ops.shape(logits)[0]
+      sequence_length = array_ops.shape(logits)[1]
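+      # Restore the [batch_size, sequence_length] shape so the loss can be
+      # reduced over either axis on its own.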
+      crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
+      if average_across_timesteps and not average_across_batch:
+        crossent = math_ops.reduce_sum(crossent, axis=[1])
+        total_size = math_ops.reduce_sum(weights, axis=[1])
+        total_size += 1e-12  # to avoid division by 0 for all-0 weights
+        crossent /= total_size
+      if not average_across_timesteps and average_across_batch:
+        crossent = math_ops.reduce_sum(crossent, axis=[0])
+        total_size = math_ops.reduce_sum(weights, axis=[0])
+        total_size += 1e-12  # to avoid division by 0 for all-0 weights
+        crossent /= total_size
+    return crossent
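
For reference, a minimal usage sketch of the new API follows. It is not part
of the patch; it assumes a TensorFlow build where
tf.contrib.seq2seq.sequence_loss is available, and the shapes and values are
chosen purely for illustration:

import tensorflow as tf

batch_size, sequence_length, num_classes = 2, 3, 5
logits = tf.random_normal([batch_size, sequence_length, num_classes])
targets = tf.zeros([batch_size, sequence_length], dtype=tf.int32)
# Weight 1.0 marks a valid timestep; 0.0 would mask out padding.
weights = tf.ones([batch_size, sequence_length])

# Scalar average loss per timestep (both averaging flags set).
loss = tf.contrib.seq2seq.sequence_loss(
    logits, targets, weights,
    average_across_timesteps=True,
    average_across_batch=True)

with tf.Session() as sess:
  print(sess.run(loss))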