diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py | 50 |
1 files changed, 47 insertions, 3 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py index f99de76f17..95560fb254 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py @@ -20,14 +20,58 @@ from __future__ import division from __future__ import print_function # pylint: enable=unused-import +import numpy as np import tensorflow as tf - class LossTest(tf.test.TestCase): - def testLoss(self): - pass + def testSequenceLoss(self): + with self.test_session() as sess: + with tf.variable_scope("root", + initializer=tf.constant_initializer(0.5)) as varscope: + batch_size = 2 + sequence_length = 3 + number_of_classes = 5 + logits = [tf.constant(i + 0.5, shape=[batch_size, number_of_classes]) + for i in range(sequence_length)] + logits = tf.stack(logits, axis=1) + targets = [tf.constant(i, tf.int32, shape=[batch_size]) for i in + range(sequence_length)] + targets = tf.stack(targets, axis=1) + weights = [tf.constant(1.0, shape=[batch_size]) for i in + range(sequence_length)] + weights = tf.stack(weights, axis=1) + + average_loss_per_example = tf.contrib.seq2seq.sequence_loss( + logits, targets, weights, + average_across_timesteps=True, + average_across_batch=True) + res = sess.run(average_loss_per_example) + self.assertAllClose(1.60944, res) + + average_loss_per_sequence = tf.contrib.seq2seq.sequence_loss( + logits, targets, weights, + average_across_timesteps=False, + average_across_batch=True) + res = sess.run(average_loss_per_sequence) + compare_per_sequence = np.ones((sequence_length)) * 1.60944 + self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = tf.contrib.seq2seq.sequence_loss( + logits, targets, weights, + average_across_timesteps=True, + average_across_batch=False) + res = sess.run(average_loss_per_batch) + compare_per_batch = np.ones((batch_size)) * 1.60944 + self.assertAllClose(compare_per_batch, res) + total_loss = tf.contrib.seq2seq.sequence_loss( + logits, targets, weights, + average_across_timesteps=False, + average_across_batch=False) + res = sess.run(total_loss) + compare_total = np.ones((batch_size, sequence_length)) * 1.60944 + self.assertAllClose(compare_total, res) if __name__ == '__main__': tf.test.main() |