author Eugene Brevdo <ebrevdo@google.com> 2017-10-16 11:32:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-16 11:37:03 -0700
commit 74bd8ff717eaf08bf64f4b16c0bca40173b19614 (patch)
tree e9ee5cc03955a82cd3e20c42a2b65a89e4ae2cca
parent 2487732ff111daedaf489672700ccfbf2088c3de (diff)
[tf.contrib.seq2seq] Some light cleanup in beam search decoder code.
PiperOrigin-RevId: 172352767
-rw-r--r-- tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py | 3
-rw-r--r-- tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 71
2 files changed, 39 insertions, 35 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 2caeb9eb61..8d4ec4b4db 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -80,8 +80,7 @@ class TestEosMasking(test.TestCase):
])
eos_token = 0
- previously_finished = constant_op.constant(
- [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32)
+ previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
masked = beam_search_decoder._mask_probs(probs, eos_token,
previously_finished)
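For reference, a minimal sketch of how the updated test drives `_mask_probs` with a plain boolean numpy mask instead of a float32 constant tensor. The probability values here are invented and TF 1.x session semantics are assumed:
```
import numpy as np
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder

# [batch_size=2, beam_width=3, vocab_size=2] log-probabilities.
probs = tf.log([[[0.3, 0.7], [0.2, 0.8], [0.5, 0.5]],
                [[0.1, 0.9], [0.4, 0.6], [0.9, 0.1]]])
eos_token = 0
# True marks beams that already emitted EOS; no constant_op needed.
previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)

masked = beam_search_decoder._mask_probs(probs, eos_token,
                                         previously_finished)
with tf.Session() as sess:
    print(sess.run(masked))  # finished beams put all mass on eos_token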
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index e22912ac5c..112ac57a1b 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -20,9 +20,10 @@ from __future__ import print_function
import collections
+import numpy as np
+
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -390,17 +391,17 @@ class BeamSearchDecoder(decoder.Decoder):
We do this so that we can use nest and not run into problems with shapes.
Args:
- t: Tensor of dimension [batch_size*beam_width, s]
- s: Tensor, Python int, or TensorShape.
+ t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
+ s: `Tensor`, Python int, or `TensorShape`.
Returns:
- Either a reshaped version of t with dimension
- [batch_size, beam_width, s] if t's first dimension is of size
- batch_size*beam_width or t if not.
+ If `t` is a matrix or higher order tensor, then the return value is
+ `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is
+ returned unchanged.
Raises:
- TypeError: If t is an instance of TensorArray.
- ValueError: If the rank of t is not statically known.
+ TypeError: If `t` is an instance of `TensorArray`.
+ ValueError: If the rank of `t` is not statically known.
"""
_check_maybe(t)
if t.shape.ndims >= 1:
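Concretely, the split direction the docstring describes: a tensor whose leading dimension is `batch_size * beam_width` gets its beam dimension unflattened, while scalars pass through. A small sketch with invented static shapes:
```
import tensorflow as tf

batch_size, beam_width = 4, 3
t = tf.zeros([batch_size * beam_width, 5])          # [batch * beam] + s, s = [5]
split = tf.reshape(t, [batch_size, beam_width, 5])  # [batch, beam] + s
scalar = tf.constant(7)  # rank 0: the `maybe` variant returns it unchanged
```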
@@ -411,19 +412,19 @@ class BeamSearchDecoder(decoder.Decoder):
def _maybe_merge_batch_beams(self, t, s):
Merges the tensor from a batch of beams into a batch by beams.
- More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
- reshape this into [batch_size, beam_width, s]
+ More exactly, `t` is a tensor of dimension `[batch_size, beam_width] + s`,
+ then we reshape it to `[batch_size * beam_width] + s`.
Args:
- t: Tensor of dimension [batch_size*beam_width, s]
- s: Tensor, Python int, or TensorShape.
+ t: `Tensor` of dimension `[batch_size, beam_width] + s`.
+ s: `Tensor`, Python int, or `TensorShape`.
Returns:
- A reshaped version of t with dimension [batch_size, beam_width, s].
+ A reshaped version of `t` with shape `[batch_size * beam_width] + s`.
Raises:
- TypeError: If t is an instance of TensorArray.
- ValueError: If the rank of t is not statically known.
+ TypeError: If `t` is an instance of `TensorArray`.
+ ValueError: If the rank of `t` is not statically known.
"""
_check_maybe(t)
if t.shape.ndims >= 2:
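And the merge direction is the inverse reshape, flattening the beam dimension back into the batch before the tensor is handed to the RNN cell (again with invented shapes):
```
import tensorflow as tf

batch_size, beam_width = 4, 3
t = tf.zeros([batch_size, beam_width, 5])             # [batch, beam] + s
merged = tf.reshape(t, [batch_size * beam_width, 5])  # [batch * beam] + s
```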
@@ -521,14 +522,12 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
# Calculate the continuation lengths by adding to all continuing beams.
vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
lengths_to_add = array_ops.one_hot(
- indices=array_ops.tile(
- array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]),
+ indices=array_ops.fill([batch_size, beam_width], end_token),
depth=vocab_size,
- on_value=constant_op.constant(0, dtype=dtypes.int64),
- off_value=constant_op.constant(1, dtype=dtypes.int64),
+ on_value=np.int64(0), off_value=np.int64(1),
dtype=dtypes.int64)
- add_mask = (1 - math_ops.to_int64(previously_finished))
- lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add
+ add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
+ lengths_to_add *= array_ops.expand_dims(add_mask, 2)
new_prediction_lengths = (
lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
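The simplified mask construction, restated with the public TF 1.x API (shapes and the finished pattern are invented): `fill` replaces the `tile(reshape(...))` idiom, and the inverted `one_hot` adds 1 to every continuation except the one that emits EOS, zeroed out for beams that had already finished:
```
import numpy as np
import tensorflow as tf

batch_size, beam_width, vocab_size = 2, 3, 5
end_token = 0
indices = tf.fill([batch_size, beam_width], end_token)
# 0 in the EOS column, 1 everywhere else: [batch, beam, vocab].
lengths_to_add = tf.one_hot(indices, depth=vocab_size,
                            on_value=np.int64(0), off_value=np.int64(1),
                            dtype=tf.int64)
previously_finished = tf.constant([[True, False, False],
                                   [False, False, True]])
add_mask = tf.to_int64(tf.logical_not(previously_finished))
lengths_to_add *= tf.expand_dims(add_mask, 2)
```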
@@ -592,9 +591,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
# 1. Finished beams remain unchanged
# 2. Beams that are now finished (EOS predicted) remain unchanged
# 3. Beams that are not yet finished have their length increased by 1
- lengths_to_add = math_ops.to_int64(
- math_ops.not_equal(next_word_ids, end_token))
- lengths_to_add = (1 - math_ops.to_int64(next_finished)) * lengths_to_add
+ lengths_to_add = math_ops.to_int64(math_ops.logical_not(next_finished))
next_prediction_len = _tensor_gather_helper(
gather_indices=next_beam_ids,
gather_from=beam_state.lengths,
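Collapsing the two old masks into one is sound because `next_finished` is computed in this function as previously finished OR EOS just predicted; a quick numpy check of that equivalence, with invented ids:
```
import numpy as np

end_token = 0
next_word_ids = np.array([3, 0, 5, 0])
previously_finished = np.array([False, False, True, True])
next_finished = previously_finished | (next_word_ids == end_token)

old = ((next_word_ids != end_token).astype(np.int64)
       * (1 - next_finished.astype(np.int64)))
new = np.logical_not(next_finished).astype(np.int64)
assert (old == new).all()  # [1, 0, 0, 0] both ways
```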
@@ -652,13 +649,20 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
def _length_penalty(sequence_lengths, penalty_factor):
"""Calculates the length penalty. See https://arxiv.org/abs/1609.08144.
+ Returns the length penalty tensor:
+ ```
+ [(5+sequence_lengths)/6]**penalty_factor
+ ```
+ where all operations are performed element-wise.
+
Args:
- sequence_lengths: The sequence length of all hypotheses, a tensor
- of shape [beam_size, vocab_size].
+ sequence_lengths: `Tensor`, the sequence lengths of each hypothesis.
penalty_factor: A scalar that weights the length penalty.
Returns:
- The length penalty factor, a tensor fo shape [beam_size].
+ If the penalty is `0`, returns the scalar `1.0`. Otherwise returns
+ the length penalty factor, a tensor with the same shape as
+ `sequence_lengths`.
"""
penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor")
penalty_factor.set_shape(()) # penalty should be a scalar.
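A worked instance of the penalty from the docstring above, `((5 + len) / 6) ** alpha`, with invented lengths and the common `alpha = 0.6` from the GNMT paper:
```
import numpy as np

sequence_lengths = np.array([4., 7., 10.], dtype=np.float32)
alpha = 0.6
penalty = ((5. + sequence_lengths) / 6.) ** alpha
# alpha == 0 short-circuits to the scalar 1.0 inside _length_penalty,
# i.e. no length normalization at all.
```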
@@ -680,8 +684,7 @@ def _mask_probs(probs, eos_token, finished):
eos_token: An int32 id corresponding to the EOS token to allocate
probability to.
finished: A boolean tensor of shape `[batch_size, beam_width]` that
- specifies which
- elements in the beam are finished already.
+ specifies which elements in the beam are finished already.
Returns:
A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished
@@ -689,10 +692,12 @@ def _mask_probs(probs, eos_token, finished):
probability on the EOS token.
"""
vocab_size = array_ops.shape(probs)[2]
- finished_mask = array_ops.expand_dims(
- math_ops.to_float(1. - math_ops.to_float(finished)), 2)
+ finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype)
+ not_finished_mask = math_ops.cast(
+ array_ops.expand_dims(math_ops.logical_not(finished), 2),
+ probs.dtype)
# These examples are not finished and we leave them
- non_finished_examples = finished_mask * probs
+ non_finished_examples = not_finished_mask * probs
# All finished examples are replaced with a vector that has all
# probability on EOS
finished_row = array_ops.one_hot(
@@ -701,7 +706,7 @@ def _mask_probs(probs, eos_token, finished):
dtype=probs.dtype,
on_value=0.,
off_value=probs.dtype.min)
- finished_examples = (1. - finished_mask) * finished_row
+ finished_examples = finished_mask * finished_row
return finished_examples + non_finished_examples
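Finally, the net effect of `_mask_probs` restated in numpy (shapes and values invented): finished beams collapse onto a row that carries all mass on EOS at log-probability 0 and `dtype.min` elsewhere, while unfinished beams pass through untouched:
```
import numpy as np

probs = np.log([[[0.3, 0.7], [0.2, 0.8]]])  # [batch=1, beam=2, vocab=2]
finished = np.array([[False, True]])
eos_token = 0

finished_row = np.full(probs.shape[-1], np.finfo(probs.dtype).min)
finished_row[eos_token] = 0.
masked = np.where(finished[..., None], finished_row, probs)
```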