aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py')
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 03fe31abf7..6adbb8be40 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -299,12 +299,13 @@ class BeamSearchDecoder(decoder.Decoder):
"""
finished, start_inputs = self._finished, self._start_inputs
+ dtype = nest.flatten(self._initial_cell_state)[0].dtype
log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
- on_value=0.0,
- off_value=-np.Inf,
- dtype=nest.flatten(self._initial_cell_state)[0].dtype)
+ on_value=ops.convert_to_tensor(0.0, dtype=dtype),
+ off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
+ dtype=dtype)
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,