aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py88
1 files changed, 88 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index d2beac5f31..f498b2bb57 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -225,6 +225,94 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(next_state_.log_probs, expected_log_probs)
+class TestLargeBeamStep(test.TestCase):
+ """
+ Tests a single step of beam search in such
+ case that beam size is larger than vocabulary size.
+ """
+
+ def setUp(self):
+ super(TestLargeBeamStep, self).setUp()
+ self.batch_size = 2
+ self.beam_width = 8
+ self.vocab_size = 5
+ self.end_token = 0
+ self.length_penalty_weight = 0.6
+
+
+ def test_step(self):
+ def get_probs():
+ """this simulates the initialize method in BeamSearchDecoder"""
+ log_prob_mask = array_ops.one_hot(array_ops.zeros([self.batch_size],
+ dtype=dtypes.int32),
+ depth=self.beam_width, on_value=True,
+ off_value=False, dtype=dtypes.bool)
+
+ log_prob_zeros = array_ops.zeros([self.batch_size, self.beam_width],
+ dtype=dtypes.float32)
+ log_prob_neg_inf = array_ops.ones([self.batch_size, self.beam_width],
+ dtype=dtypes.float32) * -np.Inf
+
+ log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
+ log_prob_neg_inf)
+ return log_probs
+
+ log_probs = get_probs()
+ dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
+
+ _finished = array_ops.one_hot(
+ array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+ depth=self.beam_width, on_value=False,
+ off_value=True, dtype=dtypes.bool)
+ _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
+ _lengths[:, 0]=2
+ _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)
+
+ beam_state = beam_search_decoder.BeamSearchDecoderState(
+ cell_state=dummy_cell_state,
+ log_probs=log_probs,
+ lengths=_lengths,
+ finished=_finished)
+
+ logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
+ 0.0001)
+ logits_[0, 0, 2] = 1.9
+ logits_[0, 0, 3] = 2.1
+ logits_[0, 1, 3] = 3.1
+ logits_[0, 1, 4] = 0.9
+ logits_[1, 0, 1] = 0.5
+ logits_[1, 1, 2] = 2.7
+ logits_[1, 2, 2] = 10.0
+ logits_[1, 2, 3] = 0.2
+ logits = constant_op.constant(logits_, dtype=dtypes.float32)
+ log_probs = nn_ops.log_softmax(logits)
+
+ outputs, next_beam_state = beam_search_decoder._beam_search_step(
+ time=2,
+ logits=logits,
+ next_cell_state=dummy_cell_state,
+ beam_state=beam_state,
+ batch_size=ops.convert_to_tensor(self.batch_size),
+ beam_width=self.beam_width,
+ end_token=self.end_token,
+ length_penalty_weight=self.length_penalty_weight)
+
+ with self.test_session() as sess:
+ outputs_, next_state_, state_, log_probs_ = sess.run(
+ [outputs, next_beam_state, beam_state, log_probs])
+
+ self.assertEqual(outputs_.predicted_ids[0, 0], 3)
+ self.assertEqual(outputs_.predicted_ids[0, 1], 2)
+ self.assertEqual(outputs_.predicted_ids[1, 0], 1)
+ neg_inf = -np.Inf
+ self.assertAllEqual(next_state_.log_probs[:, -3:],
+ [[neg_inf, neg_inf, neg_inf],
+ [neg_inf, neg_inf, neg_inf]])
+ self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
+ self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
+ self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0],
+ [0, 0, 0]])
+
class BeamSearchDecoderTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention):