aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-03-15 12:58:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-15 13:03:19 -0700
commitccd8079e579604547f4b4d8a6b061cfdc6ce49bf (patch)
tree0d498e84ca32a101afcada0993a30a5e3b0452a2 /tensorflow/contrib/seq2seq
parent61032e9ca7bf9849cb65db9b646381d124080856 (diff)
Merge changes from github.
PiperOrigin-RevId: 189231636
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 03fe31abf7..6adbb8be40 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -299,12 +299,13 @@ class BeamSearchDecoder(decoder.Decoder):
"""
finished, start_inputs = self._finished, self._start_inputs
+ dtype = nest.flatten(self._initial_cell_state)[0].dtype
log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
- on_value=0.0,
- off_value=-np.Inf,
- dtype=nest.flatten(self._initial_cell_state)[0].dtype)
+ on_value=ops.convert_to_tensor(0.0, dtype=dtype),
+ off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
+ dtype=dtype)
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,