From 841646787402fe4c273857fbc59457dd15f7c102 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Wed, 14 Mar 2018 16:14:39 +0100 Subject: Support other dtypes in BeamSearchDecoder initialization (#17591) --- tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'tensorflow/contrib/seq2seq/python') diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 03fe31abf7..6adbb8be40 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -299,12 +299,13 @@ class BeamSearchDecoder(decoder.Decoder): """ finished, start_inputs = self._finished, self._start_inputs + dtype = nest.flatten(self._initial_cell_state)[0].dtype log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) array_ops.zeros([self._batch_size], dtype=dtypes.int32), depth=self._beam_width, - on_value=0.0, - off_value=-np.Inf, - dtype=nest.flatten(self._initial_cell_state)[0].dtype) + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) initial_state = BeamSearchDecoderState( cell_state=self._initial_cell_state, -- cgit v1.2.3