Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py  31
1 file changed, 15 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index ebe25ce077..a5f7169c31 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import collections
-
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
@@ -229,8 +228,11 @@ class BeamSearchDecoder(decoder.Decoder):
self._start_tokens = array_ops.tile(
array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
- self._finished = array_ops.zeros(
- [self._batch_size, self._beam_width], dtype=dtypes.bool)
+
+ self._finished = array_ops.one_hot(
+ array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+ depth=self._beam_width, on_value=False,
+ off_value=True, dtype=dtypes.bool)
@property
def batch_size(self):
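
The new _finished initializer above replaces the all-False tensor with a one-hot
pattern: within each batch entry only beam 0 starts out unfinished, and the
duplicated start beams are marked finished. A minimal sketch of what that
produces, assuming made-up sizes (batch size 2, beam width 3) and using the
public tf.one_hot in place of the internal array_ops.one_hot:

    import tensorflow as tf

    batch_size, beam_width = 2, 3   # hypothetical sizes, just for illustration
    finished = tf.one_hot(
        tf.zeros([batch_size], dtype=tf.int32),
        depth=beam_width, on_value=False, off_value=True, dtype=tf.bool)
    # When evaluated, finished is:
    #   [[False,  True,  True],
    #    [False,  True,  True]]
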
@@ -298,11 +300,15 @@ class BeamSearchDecoder(decoder.Decoder):
"""
finished, start_inputs = self._finished, self._start_inputs
+ log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
+ array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+ depth=self._beam_width, on_value=0.0, off_value=-np.Inf,
+ dtype=nest.flatten(self._initial_cell_state)[0].dtype)
+
+
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
- log_probs=array_ops.zeros(
- [self._batch_size, self._beam_width],
- dtype=nest.flatten(self._initial_cell_state)[0].dtype),
+ log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
[self._batch_size, self._beam_width], dtype=dtypes.int64))
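
The initial log_probs receive the same one-hot treatment: beam 0 starts at
log-probability 0.0 and every other beam starts at -inf, so the duplicated
start beams cannot outrank beam 0 once scores are accumulated. A sketch under
the same made-up sizes (batch size 2, beam width 3), again using the public
tf.one_hot rather than the internal array_ops.one_hot:

    import numpy as np
    import tensorflow as tf

    batch_size, beam_width = 2, 3   # hypothetical sizes, just for illustration
    log_probs = tf.one_hot(
        tf.zeros([batch_size], dtype=tf.int32),
        depth=beam_width, on_value=0.0, off_value=-np.Inf, dtype=tf.float32)
    # When evaluated, log_probs is:
    #   [[  0., -inf, -inf],
    #    [  0., -inf, -inf]]
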
@@ -563,18 +569,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
time = ops.convert_to_tensor(time, name="time")
# During the first time step we only consider the initial beam
scores_shape = array_ops.shape(scores)
- scores_flat = control_flow_ops.cond(
- time > 0,
- lambda: array_ops.reshape(scores, [batch_size, -1]),
- lambda: scores[:, 0])
- num_available_beam = control_flow_ops.cond(
- time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]),
- lambda: math_ops.reduce_prod(scores_shape[2:]))
+ scores_flat = array_ops.reshape(scores, [batch_size, -1])
# Pick the next beams according to the specified successors function
- next_beam_size = math_ops.minimum(
- ops.convert_to_tensor(beam_width, dtype=dtypes.int32, name="beam_width"),
- num_available_beam)
+ next_beam_size = ops.convert_to_tensor(beam_width, dtype=dtypes.int32,
+ name="beam_width")
next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)
next_beam_scores.set_shape([static_batch_size, beam_width])
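
With the -inf initialization above, the time > 0 conditionals in
_beam_search_step become unnecessary: at step 0 the accumulated scores of beams
1..beam_width-1 are -inf, so a plain top-k over the flattened
(batch, beam * vocab) scores can only pick continuations of beam 0. A rough
NumPy sketch of that argument, with made-up sizes and np.argsort standing in
for nn_ops.top_k:

    import numpy as np

    vocab_size, beam_width = 4, 3                # hypothetical sizes
    scores = np.full((beam_width, vocab_size), -np.inf)
    scores[0] = np.log([0.1, 0.2, 0.3, 0.4])     # only beam 0 is live at t = 0
    flat = scores.reshape(-1)                    # flatten beams and vocab together
    top_k = np.argsort(flat)[::-1][:beam_width]  # stand-in for nn_ops.top_k
    # every selected index falls inside beam 0's slice of the flat vector
    assert (top_k < vocab_size).all()
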