diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 31 |
1 files changed, 15 insertions, 16 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index ebe25ce077..a5f7169c31 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import collections - import numpy as np from tensorflow.contrib.seq2seq.python.ops import beam_search_ops @@ -229,8 +228,11 @@ class BeamSearchDecoder(decoder.Decoder): self._start_tokens = array_ops.tile( array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) self._start_inputs = self._embedding_fn(self._start_tokens) - self._finished = array_ops.zeros( - [self._batch_size, self._beam_width], dtype=dtypes.bool) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, on_value=False, + off_value=True, dtype=dtypes.bool) @property def batch_size(self): @@ -298,11 +300,15 @@ class BeamSearchDecoder(decoder.Decoder): """ finished, start_inputs = self._finished, self._start_inputs + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, on_value=0.0, off_value=-np.Inf, + dtype=nest.flatten(self._initial_cell_state)[0].dtype) + + initial_state = BeamSearchDecoderState( cell_state=self._initial_cell_state, - log_probs=array_ops.zeros( - [self._batch_size, self._beam_width], - dtype=nest.flatten(self._initial_cell_state)[0].dtype), + log_probs=log_probs, finished=finished, lengths=array_ops.zeros( [self._batch_size, self._beam_width], dtype=dtypes.int64)) @@ -563,18 +569,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, time = ops.convert_to_tensor(time, name="time") # During the first time step we only consider the initial beam scores_shape = array_ops.shape(scores) - scores_flat = control_flow_ops.cond( - time > 0, - lambda: array_ops.reshape(scores, [batch_size, -1]), - lambda: scores[:, 0]) - num_available_beam = control_flow_ops.cond( - time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]), - lambda: math_ops.reduce_prod(scores_shape[2:])) + scores_flat = array_ops.reshape(scores, [batch_size, -1]) # Pick the next beams according to the specified successors function - next_beam_size = math_ops.minimum( - ops.convert_to_tensor(beam_width, dtype=dtypes.int32, name="beam_width"), - num_available_beam) + next_beam_size = ops.convert_to_tensor(beam_width, dtype=dtypes.int32, + name="beam_width") next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size) next_beam_scores.set_shape([static_batch_size, beam_width]) |