aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py50
1 files changed, 47 insertions, 3 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
index f99de76f17..95560fb254 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
@@ -20,14 +20,58 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
+import numpy as np
import tensorflow as tf
-
class LossTest(tf.test.TestCase):
- def testLoss(self):
- pass
+ def testSequenceLoss(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root",
+ initializer=tf.constant_initializer(0.5)) as varscope:
+ batch_size = 2
+ sequence_length = 3
+ number_of_classes = 5
+ logits = [tf.constant(i + 0.5, shape=[batch_size, number_of_classes])
+ for i in range(sequence_length)]
+ logits = tf.stack(logits, axis=1)
+ targets = [tf.constant(i, tf.int32, shape=[batch_size]) for i in
+ range(sequence_length)]
+ targets = tf.stack(targets, axis=1)
+ weights = [tf.constant(1.0, shape=[batch_size]) for i in
+ range(sequence_length)]
+ weights = tf.stack(weights, axis=1)
+
+ average_loss_per_example = tf.contrib.seq2seq.sequence_loss(
+ logits, targets, weights,
+ average_across_timesteps=True,
+ average_across_batch=True)
+ res = sess.run(average_loss_per_example)
+ self.assertAllClose(1.60944, res)
+
+ average_loss_per_sequence = tf.contrib.seq2seq.sequence_loss(
+ logits, targets, weights,
+ average_across_timesteps=False,
+ average_across_batch=True)
+ res = sess.run(average_loss_per_sequence)
+ compare_per_sequence = np.ones((sequence_length)) * 1.60944
+ self.assertAllClose(compare_per_sequence, res)
+
+ average_loss_per_batch = tf.contrib.seq2seq.sequence_loss(
+ logits, targets, weights,
+ average_across_timesteps=True,
+ average_across_batch=False)
+ res = sess.run(average_loss_per_batch)
+ compare_per_batch = np.ones((batch_size)) * 1.60944
+ self.assertAllClose(compare_per_batch, res)
+ total_loss = tf.contrib.seq2seq.sequence_loss(
+ logits, targets, weights,
+ average_across_timesteps=False,
+ average_across_batch=False)
+ res = sess.run(total_loss)
+ compare_total = np.ones((batch_size, sequence_length)) * 1.60944
+ self.assertAllClose(compare_total, res)
if __name__ == '__main__':
tf.test.main()