diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index d2beac5f31..f498b2bb57 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -225,6 +225,94 @@ class TestBeamStep(test.TestCase): self.assertAllEqual(next_state_.log_probs, expected_log_probs) +class TestLargeBeamStep(test.TestCase): + """ + Tests a single step of beam search in such + case that beam size is larger than vocabulary size. + """ + + def setUp(self): + super(TestLargeBeamStep, self).setUp() + self.batch_size = 2 + self.beam_width = 8 + self.vocab_size = 5 + self.end_token = 0 + self.length_penalty_weight = 0.6 + + + def test_step(self): + def get_probs(): + """this simulates the initialize method in BeamSearchDecoder""" + log_prob_mask = array_ops.one_hot(array_ops.zeros([self.batch_size], + dtype=dtypes.int32), + depth=self.beam_width, on_value=True, + off_value=False, dtype=dtypes.bool) + + log_prob_zeros = array_ops.zeros([self.batch_size, self.beam_width], + dtype=dtypes.float32) + log_prob_neg_inf = array_ops.ones([self.batch_size, self.beam_width], + dtype=dtypes.float32) * -np.Inf + + log_probs = array_ops.where(log_prob_mask, log_prob_zeros, + log_prob_neg_inf) + return log_probs + + log_probs = get_probs() + dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width]) + + _finished = array_ops.one_hot( + array_ops.zeros([self.batch_size], dtype=dtypes.int32), + depth=self.beam_width, on_value=False, + off_value=True, dtype=dtypes.bool) + _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64) + _lengths[:, 0]=2 + _lengths = constant_op.constant(_lengths, dtype=dtypes.int64) + + beam_state = beam_search_decoder.BeamSearchDecoderState( + cell_state=dummy_cell_state, + log_probs=log_probs, + lengths=_lengths, + finished=_finished) + + logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size], + 0.0001) + logits_[0, 0, 2] = 1.9 + logits_[0, 0, 3] = 2.1 + logits_[0, 1, 3] = 3.1 + logits_[0, 1, 4] = 0.9 + logits_[1, 0, 1] = 0.5 + logits_[1, 1, 2] = 2.7 + logits_[1, 2, 2] = 10.0 + logits_[1, 2, 3] = 0.2 + logits = constant_op.constant(logits_, dtype=dtypes.float32) + log_probs = nn_ops.log_softmax(logits) + + outputs, next_beam_state = beam_search_decoder._beam_search_step( + time=2, + logits=logits, + next_cell_state=dummy_cell_state, + beam_state=beam_state, + batch_size=ops.convert_to_tensor(self.batch_size), + beam_width=self.beam_width, + end_token=self.end_token, + length_penalty_weight=self.length_penalty_weight) + + with self.test_session() as sess: + outputs_, next_state_, state_, log_probs_ = sess.run( + [outputs, next_beam_state, beam_state, log_probs]) + + self.assertEqual(outputs_.predicted_ids[0, 0], 3) + self.assertEqual(outputs_.predicted_ids[0, 1], 2) + self.assertEqual(outputs_.predicted_ids[1, 0], 1) + neg_inf = -np.Inf + self.assertAllEqual(next_state_.log_probs[:, -3:], + [[neg_inf, neg_inf, neg_inf], + [neg_inf, neg_inf, neg_inf]]) + self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True) + self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True) + self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], + [0, 0, 0]]) + class BeamSearchDecoderTest(test.TestCase): def _testDynamicDecodeRNN(self, time_major, has_attention): |