Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py  121
1 file changed, 106 insertions, 15 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index d2beac5f31..9265540317 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -46,20 +46,18 @@ class TestGatherTree(test.TestCase):
# create (batch_size, max_time, beam_width) matrix and transpose it
predicted_ids = np.array(
- [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
- [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
+ [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
dtype=np.int32).transpose([1, 0, 2])
parent_ids = np.array(
- [[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
- [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
+ [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
dtype=np.int32).transpose([1, 0, 2])
# max_sequence_lengths is shaped (batch_size) = (2)
max_sequence_lengths = [3, 3]
- expected_result = np.array(
- [[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
- [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
+ expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
+ [[2, 4, 4], [7, 6, 6],
+ [8, 9, 10]]]).transpose([1, 0, 2])
res = beam_search_ops.gather_tree(
predicted_ids,
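
The expected_result above can be reproduced by hand: gather_tree starts at the last valid time step of each batch entry and walks the parent_ids pointers backwards, emitting the predicted id of the selected beam at every step. A rough pure-NumPy sketch of that backtracking (the name gather_tree_reference is made up for illustration, and the real op additionally replaces entries after the end token, which is not shown here):

import numpy as np

def gather_tree_reference(step_ids, parent_ids, max_sequence_lengths, end_token):
  # step_ids and parent_ids are time-major: [max_time, batch_size, beam_width].
  max_time, batch_size, beam_width = step_ids.shape
  beams = np.full_like(step_ids, end_token)
  for batch in range(batch_size):
    limit = min(int(max_sequence_lengths[batch]), max_time)
    for beam in range(beam_width):
      parent = beam  # start the backward walk from the beam itself
      for t in range(limit - 1, -1, -1):
        beams[t, batch, beam] = step_ids[t, batch, parent]
        parent = parent_ids[t, batch, parent]
  return beams

Running this sketch on the (already transposed, time-major) predicted_ids and parent_ids above yields exactly the expected_result that the test compares against beam_search_ops.gather_tree.
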
@@ -157,8 +155,8 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
- self.assertAllEqual(next_state_.finished, [[False, False, False],
- [False, False, False]])
+ self.assertAllEqual(next_state_.finished,
+ [[False, False, False], [False, False, False]])
expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -212,8 +210,8 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
- self.assertAllEqual(next_state_.finished, [[True, False, False],
- [False, True, False]])
+ self.assertAllEqual(next_state_.finished,
+ [[True, False, False], [False, True, False]])
expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -225,6 +223,100 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(next_state_.log_probs, expected_log_probs)
+class TestLargeBeamStep(test.TestCase):
+ """Tests large beam step.
+
+ Tests a single step of beam search in such case that beam size is larger than
+ vocabulary size.
+ """
+
+ def setUp(self):
+ super(TestLargeBeamStep, self).setUp()
+ self.batch_size = 2
+ self.beam_width = 8
+ self.vocab_size = 5
+ self.end_token = 0
+ self.length_penalty_weight = 0.6
+
+ def test_step(self):
+
+ def get_probs():
+ """this simulates the initialize method in BeamSearchDecoder."""
+ log_prob_mask = array_ops.one_hot(
+ array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+ depth=self.beam_width,
+ on_value=True,
+ off_value=False,
+ dtype=dtypes.bool)
+
+ log_prob_zeros = array_ops.zeros(
+ [self.batch_size, self.beam_width], dtype=dtypes.float32)
+ log_prob_neg_inf = array_ops.ones(
+ [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf
+
+ log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
+ log_prob_neg_inf)
+ return log_probs
+
+ log_probs = get_probs()
+ dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
+
+ # pylint: disable=invalid-name
+ _finished = array_ops.one_hot(
+ array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+ depth=self.beam_width,
+ on_value=False,
+ off_value=True,
+ dtype=dtypes.bool)
+ _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
+ _lengths[:, 0] = 2
+ _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)
+
+ beam_state = beam_search_decoder.BeamSearchDecoderState(
+ cell_state=dummy_cell_state,
+ log_probs=log_probs,
+ lengths=_lengths,
+ finished=_finished)
+
+ logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
+ 0.0001)
+ logits_[0, 0, 2] = 1.9
+ logits_[0, 0, 3] = 2.1
+ logits_[0, 1, 3] = 3.1
+ logits_[0, 1, 4] = 0.9
+ logits_[1, 0, 1] = 0.5
+ logits_[1, 1, 2] = 2.7
+ logits_[1, 2, 2] = 10.0
+ logits_[1, 2, 3] = 0.2
+ logits = constant_op.constant(logits_, dtype=dtypes.float32)
+ log_probs = nn_ops.log_softmax(logits)
+
+ outputs, next_beam_state = beam_search_decoder._beam_search_step(
+ time=2,
+ logits=logits,
+ next_cell_state=dummy_cell_state,
+ beam_state=beam_state,
+ batch_size=ops.convert_to_tensor(self.batch_size),
+ beam_width=self.beam_width,
+ end_token=self.end_token,
+ length_penalty_weight=self.length_penalty_weight)
+
+ with self.test_session() as sess:
+ outputs_, next_state_, _, _ = sess.run(
+ [outputs, next_beam_state, beam_state, log_probs])
+
+ self.assertEqual(outputs_.predicted_ids[0, 0], 3)
+ self.assertEqual(outputs_.predicted_ids[0, 1], 2)
+ self.assertEqual(outputs_.predicted_ids[1, 0], 1)
+ neg_inf = -np.Inf
+ self.assertAllEqual(
+ next_state_.log_probs[:, -3:],
+ [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
+ self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
+ self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
+ self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
+
+
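The assertions in TestLargeBeamStep follow from how the first-step log probabilities are masked. get_probs builds the same mask that BeamSearchDecoder uses at initialization: beam 0 gets log probability 0.0 and every other beam gets -inf, so each batch entry effectively expands a single live beam. That one live beam can only propose vocab_size = 5 candidates, so a top-beam_width = 8 selection must leave 3 trailing beams at -inf score and length 0, which is exactly what the test checks. A small NumPy sketch of that counting argument (illustrative only; it ignores the length penalty and the finished-beam handling that _beam_search_step also performs):

import numpy as np

batch_size, beam_width, vocab_size = 2, 8, 5   # values from setUp above

previous_log_probs = np.full([batch_size, beam_width], -np.inf, dtype=np.float32)
previous_log_probs[:, 0] = 0.0                 # only beam 0 is live, as in get_probs()

# Uniform per-step log probabilities are enough to show the shape of the result.
step_log_probs = np.full([batch_size, beam_width, vocab_size],
                         np.log(1.0 / vocab_size), dtype=np.float32)
total = previous_log_probs[:, :, None] + step_log_probs     # [batch, beam, vocab]
flat = total.reshape(batch_size, beam_width * vocab_size)

top_scores = -np.sort(-flat, axis=-1)[:, :beam_width]        # best beam_width candidates
print(np.isfinite(top_scores).sum(axis=-1))                   # [5 5]: only vocab_size finite beams
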
class BeamSearchDecoderTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention):
@@ -250,8 +342,8 @@ class BeamSearchDecoderTest(test.TestCase):
initial_state = cell.zero_state(batch_size, dtypes.float32)
if has_attention:
inputs = array_ops.placeholder_with_default(
- np.random.randn(batch_size, decoder_max_time,
- input_depth).astype(np.float32),
+ np.random.randn(batch_size, decoder_max_time, input_depth).astype(
+ np.float32),
shape=(None, None, input_depth))
tiled_inputs = beam_search_decoder.tile_batch(
inputs, multiplier=beam_width)
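
tile_batch, used just above for the attention memory, repeats every batch entry beam_width times along the batch dimension so the memory lines up with the [batch_size * beam_width] layout the beam search decoder runs in. A minimal shape sketch with made-up sizes (np.repeat mirrors the contiguous per-entry tiling that tile_batch performs):

import numpy as np

batch_size, decoder_max_time, input_depth, beam_width = 8, 4, 7, 3   # illustrative only
memory = np.random.randn(batch_size, decoder_max_time, input_depth).astype(np.float32)

# Each batch row b is repeated beam_width times: [b0, b0, b0, b1, b1, b1, ...]
tiled = np.repeat(memory, beam_width, axis=0)
print(tiled.shape)   # (24, 4, 7) == (batch_size * beam_width, decoder_max_time, input_depth)
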
@@ -271,8 +363,7 @@ class BeamSearchDecoderTest(test.TestCase):
cell_state = cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
if has_attention:
- cell_state = cell_state.clone(
- cell_state=initial_state)
+ cell_state = cell_state.clone(cell_state=initial_state)
bsd = beam_search_decoder.BeamSearchDecoder(
cell=cell,
embedding=embedding,