From 377dd3d0d51f93f22eadfd18f4186c27d8506d69 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Oct 2017 10:11:42 -0700 Subject: Use tf.where instead of multiplies when masking probabilities in the BeamSearchDecoder. PiperOrigin-RevId: 173273139 --- .../contrib/seq2seq/python/ops/beam_search_decoder.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index a88d4f5b8b..5be0c92243 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -715,12 +715,6 @@ def _mask_probs(probs, eos_token, finished): probability on the EOS token. """ vocab_size = array_ops.shape(probs)[2] - finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype) - not_finished_mask = math_ops.cast( - array_ops.expand_dims(math_ops.logical_not(finished), 2), - probs.dtype) - # These examples are not finished and we leave them - non_finished_examples = not_finished_mask * probs # All finished examples are replaced with a vector that has all # probability on EOS finished_row = array_ops.one_hot( @@ -729,8 +723,13 @@ def _mask_probs(probs, eos_token, finished): dtype=probs.dtype, on_value=0., off_value=probs.dtype.min) - finished_examples = finished_mask * finished_row - return finished_examples + non_finished_examples + finished_probs = array_ops.tile( + array_ops.reshape(finished_row, [1, 1, -1]), + array_ops.concat([array_ops.shape(finished), [1]], 0)) + finished_mask = array_ops.tile( + array_ops.expand_dims(finished, 2), [1, 1, vocab_size]) + + return array_ops.where(finished_mask, finished_probs, probs) def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, -- cgit v1.2.3