diff options
author | Jianwei Xie <xiejw@google.com> | 2018-01-24 10:02:35 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-24 10:06:06 -0800 |
commit | d9f93c42a50b1f1401d9c186eac0ae8dc9093c3b (patch) | |
tree | 178d1a692f56580c266139642b5a1d0d155c477e /tensorflow/contrib/seq2seq | |
parent | 7b62a71e2d46c148df7d5704972f4592bc5e0f1b (diff) |
Merge changes from github.
PiperOrigin-RevId: 183100142
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py | 88 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 31 |
2 files changed, 103 insertions, 16 deletions
```diff
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index d2beac5f31..f498b2bb57 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -225,6 +225,94 @@ class TestBeamStep(test.TestCase):
     self.assertAllEqual(next_state_.log_probs, expected_log_probs)
 
 
+class TestLargeBeamStep(test.TestCase):
+  """
+  Tests a single step of beam search in such
+  case that beam size is larger than vocabulary size.
+  """
+
+  def setUp(self):
+    super(TestLargeBeamStep, self).setUp()
+    self.batch_size = 2
+    self.beam_width = 8
+    self.vocab_size = 5
+    self.end_token = 0
+    self.length_penalty_weight = 0.6
+
+
+  def test_step(self):
+    def get_probs():
+      """this simulates the initialize method in BeamSearchDecoder"""
+      log_prob_mask = array_ops.one_hot(array_ops.zeros([self.batch_size],
+                                                        dtype=dtypes.int32),
+                                        depth=self.beam_width, on_value=True,
+                                        off_value=False, dtype=dtypes.bool)
+
+      log_prob_zeros = array_ops.zeros([self.batch_size, self.beam_width],
+                                       dtype=dtypes.float32)
+      log_prob_neg_inf = array_ops.ones([self.batch_size, self.beam_width],
+                                        dtype=dtypes.float32) * -np.Inf
+
+      log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
+                                  log_prob_neg_inf)
+      return log_probs
+
+    log_probs = get_probs()
+    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
+
+    _finished = array_ops.one_hot(
+        array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+        depth=self.beam_width, on_value=False,
+        off_value=True, dtype=dtypes.bool)
+    _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
+    _lengths[:, 0]=2
+    _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)
+
+    beam_state = beam_search_decoder.BeamSearchDecoderState(
+        cell_state=dummy_cell_state,
+        log_probs=log_probs,
+        lengths=_lengths,
+        finished=_finished)
+
+    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
+                      0.0001)
+    logits_[0, 0, 2] = 1.9
+    logits_[0, 0, 3] = 2.1
+    logits_[0, 1, 3] = 3.1
+    logits_[0, 1, 4] = 0.9
+    logits_[1, 0, 1] = 0.5
+    logits_[1, 1, 2] = 2.7
+    logits_[1, 2, 2] = 10.0
+    logits_[1, 2, 3] = 0.2
+    logits = constant_op.constant(logits_, dtype=dtypes.float32)
+    log_probs = nn_ops.log_softmax(logits)
+
+    outputs, next_beam_state = beam_search_decoder._beam_search_step(
+        time=2,
+        logits=logits,
+        next_cell_state=dummy_cell_state,
+        beam_state=beam_state,
+        batch_size=ops.convert_to_tensor(self.batch_size),
+        beam_width=self.beam_width,
+        end_token=self.end_token,
+        length_penalty_weight=self.length_penalty_weight)
+
+    with self.test_session() as sess:
+      outputs_, next_state_, state_, log_probs_ = sess.run(
+          [outputs, next_beam_state, beam_state, log_probs])
+
+    self.assertEqual(outputs_.predicted_ids[0, 0], 3)
+    self.assertEqual(outputs_.predicted_ids[0, 1], 2)
+    self.assertEqual(outputs_.predicted_ids[1, 0], 1)
+    neg_inf = -np.Inf
+    self.assertAllEqual(next_state_.log_probs[:, -3:],
+                        [[neg_inf, neg_inf, neg_inf],
+                         [neg_inf, neg_inf, neg_inf]])
+    self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
+    self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
+    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0],
+                                                      [0, 0, 0]])
+
 class BeamSearchDecoderTest(test.TestCase):
 
   def _testDynamicDecodeRNN(self, time_major, has_attention):
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index ebe25ce077..a5f7169c31 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-
 import numpy as np
 
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
@@ -229,8 +228,11 @@ class BeamSearchDecoder(decoder.Decoder):
     self._start_tokens = array_ops.tile(
         array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
     self._start_inputs = self._embedding_fn(self._start_tokens)
-    self._finished = array_ops.zeros(
-        [self._batch_size, self._beam_width], dtype=dtypes.bool)
+
+    self._finished = array_ops.one_hot(
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width, on_value=False,
+        off_value=True, dtype=dtypes.bool)
 
   @property
   def batch_size(self):
@@ -298,11 +300,15 @@ class BeamSearchDecoder(decoder.Decoder):
     """
     finished, start_inputs = self._finished, self._start_inputs
 
+    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
+        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+        depth=self._beam_width, on_value=0.0, off_value=-np.Inf,
+        dtype=nest.flatten(self._initial_cell_state)[0].dtype)
+
+
     initial_state = BeamSearchDecoderState(
         cell_state=self._initial_cell_state,
-        log_probs=array_ops.zeros(
-            [self._batch_size, self._beam_width],
-            dtype=nest.flatten(self._initial_cell_state)[0].dtype),
+        log_probs=log_probs,
         finished=finished,
         lengths=array_ops.zeros(
             [self._batch_size, self._beam_width], dtype=dtypes.int64))
@@ -563,18 +569,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
   time = ops.convert_to_tensor(time, name="time")
   # During the first time step we only consider the initial beam
   scores_shape = array_ops.shape(scores)
-  scores_flat = control_flow_ops.cond(
-      time > 0,
-      lambda: array_ops.reshape(scores, [batch_size, -1]),
-      lambda: scores[:, 0])
-  num_available_beam = control_flow_ops.cond(
-      time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]),
-      lambda: math_ops.reduce_prod(scores_shape[2:]))
+  scores_flat = array_ops.reshape(scores, [batch_size, -1])
 
   # Pick the next beams according to the specified successors function
-  next_beam_size = math_ops.minimum(
-      ops.convert_to_tensor(beam_width, dtype=dtypes.int32, name="beam_width"),
-      num_available_beam)
+  next_beam_size = ops.convert_to_tensor(beam_width, dtype=dtypes.int32,
+                                         name="beam_width")
   next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)
 
   next_beam_scores.set_shape([static_batch_size, beam_width])
```