aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-01-24 10:02:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-24 10:06:06 -0800
commitd9f93c42a50b1f1401d9c186eac0ae8dc9093c3b (patch)
tree178d1a692f56580c266139642b5a1d0d155c477e /tensorflow/contrib/seq2seq
parent7b62a71e2d46c148df7d5704972f4592bc5e0f1b (diff)
Merge changes from github.
PiperOrigin-RevId: 183100142
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py88
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py31
2 files changed, 103 insertions, 16 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index d2beac5f31..f498b2bb57 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -225,6 +225,94 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(next_state_.log_probs, expected_log_probs)
+class TestLargeBeamStep(test.TestCase):
+ """
+ Tests a single step of beam search in such
+ case that beam size is larger than vocabulary size.
+ """
+
+ def setUp(self):
+ super(TestLargeBeamStep, self).setUp()
+ self.batch_size = 2
+ self.beam_width = 8
+ self.vocab_size = 5
+ self.end_token = 0
+ self.length_penalty_weight = 0.6
+
+
+ def test_step(self):
+ def get_probs():
+ """this simulates the initialize method in BeamSearchDecoder"""
+ log_prob_mask = array_ops.one_hot(array_ops.zeros([self.batch_size],
+ dtype=dtypes.int32),
+ depth=self.beam_width, on_value=True,
+ off_value=False, dtype=dtypes.bool)
+
+ log_prob_zeros = array_ops.zeros([self.batch_size, self.beam_width],
+ dtype=dtypes.float32)
+ log_prob_neg_inf = array_ops.ones([self.batch_size, self.beam_width],
+ dtype=dtypes.float32) * -np.Inf
+
+ log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
+ log_prob_neg_inf)
+ return log_probs
+
+ log_probs = get_probs()
+ dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
+
+ _finished = array_ops.one_hot(
+ array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+ depth=self.beam_width, on_value=False,
+ off_value=True, dtype=dtypes.bool)
+ _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
+ _lengths[:, 0]=2
+ _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)
+
+ beam_state = beam_search_decoder.BeamSearchDecoderState(
+ cell_state=dummy_cell_state,
+ log_probs=log_probs,
+ lengths=_lengths,
+ finished=_finished)
+
+ logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
+ 0.0001)
+ logits_[0, 0, 2] = 1.9
+ logits_[0, 0, 3] = 2.1
+ logits_[0, 1, 3] = 3.1
+ logits_[0, 1, 4] = 0.9
+ logits_[1, 0, 1] = 0.5
+ logits_[1, 1, 2] = 2.7
+ logits_[1, 2, 2] = 10.0
+ logits_[1, 2, 3] = 0.2
+ logits = constant_op.constant(logits_, dtype=dtypes.float32)
+ log_probs = nn_ops.log_softmax(logits)
+
+ outputs, next_beam_state = beam_search_decoder._beam_search_step(
+ time=2,
+ logits=logits,
+ next_cell_state=dummy_cell_state,
+ beam_state=beam_state,
+ batch_size=ops.convert_to_tensor(self.batch_size),
+ beam_width=self.beam_width,
+ end_token=self.end_token,
+ length_penalty_weight=self.length_penalty_weight)
+
+ with self.test_session() as sess:
+ outputs_, next_state_, state_, log_probs_ = sess.run(
+ [outputs, next_beam_state, beam_state, log_probs])
+
+ self.assertEqual(outputs_.predicted_ids[0, 0], 3)
+ self.assertEqual(outputs_.predicted_ids[0, 1], 2)
+ self.assertEqual(outputs_.predicted_ids[1, 0], 1)
+ neg_inf = -np.Inf
+ self.assertAllEqual(next_state_.log_probs[:, -3:],
+ [[neg_inf, neg_inf, neg_inf],
+ [neg_inf, neg_inf, neg_inf]])
+ self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
+ self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
+ self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0],
+ [0, 0, 0]])
+
class BeamSearchDecoderTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention):
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index ebe25ce077..a5f7169c31 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import collections
-
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
@@ -229,8 +228,11 @@ class BeamSearchDecoder(decoder.Decoder):
self._start_tokens = array_ops.tile(
array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
- self._finished = array_ops.zeros(
- [self._batch_size, self._beam_width], dtype=dtypes.bool)
+
+ self._finished = array_ops.one_hot(
+ array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+ depth=self._beam_width, on_value=False,
+ off_value=True, dtype=dtypes.bool)
@property
def batch_size(self):
@@ -298,11 +300,15 @@ class BeamSearchDecoder(decoder.Decoder):
"""
finished, start_inputs = self._finished, self._start_inputs
+ log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
+ array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+ depth=self._beam_width, on_value=0.0, off_value=-np.Inf,
+ dtype=nest.flatten(self._initial_cell_state)[0].dtype)
+
+
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
- log_probs=array_ops.zeros(
- [self._batch_size, self._beam_width],
- dtype=nest.flatten(self._initial_cell_state)[0].dtype),
+ log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
[self._batch_size, self._beam_width], dtype=dtypes.int64))
@@ -563,18 +569,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
time = ops.convert_to_tensor(time, name="time")
# During the first time step we only consider the initial beam
scores_shape = array_ops.shape(scores)
- scores_flat = control_flow_ops.cond(
- time > 0,
- lambda: array_ops.reshape(scores, [batch_size, -1]),
- lambda: scores[:, 0])
- num_available_beam = control_flow_ops.cond(
- time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]),
- lambda: math_ops.reduce_prod(scores_shape[2:]))
+ scores_flat = array_ops.reshape(scores, [batch_size, -1])
# Pick the next beams according to the specified successors function
- next_beam_size = math_ops.minimum(
- ops.convert_to_tensor(beam_width, dtype=dtypes.int32, name="beam_width"),
- num_available_beam)
+ next_beam_size = ops.convert_to_tensor(beam_width, dtype=dtypes.int32,
+ name="beam_width")
next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)
next_beam_scores.set_shape([static_batch_size, beam_width])