diff options
author | Jacques Pienaar <jpienaar@google.com> | 2018-03-15 12:58:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-15 13:03:19 -0700 |
commit | ccd8079e579604547f4b4d8a6b061cfdc6ce49bf (patch) | |
tree | 0d498e84ca32a101afcada0993a30a5e3b0452a2 /tensorflow/contrib/seq2seq | |
parent | 61032e9ca7bf9849cb65db9b646381d124080856 (diff) |
Merge changes from github.
PiperOrigin-RevId: 189231636
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 03fe31abf7..6adbb8be40 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -299,12 +299,13 @@ class BeamSearchDecoder(decoder.Decoder): """ finished, start_inputs = self._finished, self._start_inputs + dtype = nest.flatten(self._initial_cell_state)[0].dtype log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) array_ops.zeros([self._batch_size], dtype=dtypes.int32), depth=self._beam_width, - on_value=0.0, - off_value=-np.Inf, - dtype=nest.flatten(self._initial_cell_state)[0].dtype) + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) initial_state = BeamSearchDecoderState( cell_state=self._initial_cell_state, |