about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-10-24 10:11:42 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-10-24 10:16:34 -0700
commit: 377dd3d0d51f93f22eadfd18f4186c27d8506d69 (patch)
tree: 58896fe4f064bd96ea9512f2a51a9ae62c79ba8e
parent: 58b071639d97afdbc5ac5e222a4be81dcb344962 (diff)
Use tf.where instead of multiplies when masking probabilities in the BeamSearchDecoder.
PiperOrigin-RevId: 173273139
-rw-r--r-- tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py 15
1 files changed, 7 insertions, 8 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index a88d4f5b8b..5be0c92243 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -715,12 +715,6 @@ def _mask_probs(probs, eos_token, finished):
probability on the EOS token.
"""
vocab_size = array_ops.shape(probs)[2]
- finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype)
- not_finished_mask = math_ops.cast(
- array_ops.expand_dims(math_ops.logical_not(finished), 2),
- probs.dtype)
- # These examples are not finished and we leave them
- non_finished_examples = not_finished_mask * probs
# All finished examples are replaced with a vector that has all
# probability on EOS
finished_row = array_ops.one_hot(
@@ -729,8 +723,13 @@ def _mask_probs(probs, eos_token, finished):
dtype=probs.dtype,
on_value=0.,
off_value=probs.dtype.min)
- finished_examples = finished_mask * finished_row
- return finished_examples + non_finished_examples
+ finished_probs = array_ops.tile(
+ array_ops.reshape(finished_row, [1, 1, -1]),
+ array_ops.concat([array_ops.shape(finished), [1]], 0))
+ finished_mask = array_ops.tile(
+ array_ops.expand_dims(finished, 2), [1, 1, vocab_size])
+
+ return array_ops.where(finished_mask, finished_probs, probs)
def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,