aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-15 15:44:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-15 15:49:01 -0700
commit1e75c69339da2fbf2c5c5fbeb891243badae7ff8 (patch)
tree81268169bbff6836bfbbd4e9866a1374f597a624 /tensorflow/contrib/seq2seq
parent6c62e650252ab32f83637a8de6720e73ffeca226 (diff)
Automated g4 rollback of changelist 189231636
PiperOrigin-RevId: 189258641
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 6adbb8be40..03fe31abf7 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -299,13 +299,12 @@ class BeamSearchDecoder(decoder.Decoder):
"""
finished, start_inputs = self._finished, self._start_inputs
- dtype = nest.flatten(self._initial_cell_state)[0].dtype
log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
array_ops.zeros([self._batch_size], dtype=dtypes.int32),
depth=self._beam_width,
- on_value=ops.convert_to_tensor(0.0, dtype=dtype),
- off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
- dtype=dtype)
+ on_value=0.0,
+ off_value=-np.Inf,
+ dtype=nest.flatten(self._initial_cell_state)[0].dtype)
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,