author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-12 12:23:43 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-12 12:27:53 -0700
commit     f02a2fad042bc401f3d1c89a9fd52e40ca5d1835 (patch)
tree       30d4ab7ef3aed5aeef4c1d8fe5214614e6bb1f0a /tensorflow/contrib/seq2seq
parent     53b57715f5604a5d09a9ddc73bbbf54f1d1142ed (diff)
Support coverage penalty for beam search decoder (according to https://arxiv.org/pdf/1609.08144.pdf).
PiperOrigin-RevId: 212683753
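
The penalty follows the coverage term of the cited GNMT paper: attention probabilities are accumulated per source position over all decoded steps, clipped at 1.0, log-transformed, summed over source positions, and scaled by the penalty weight. A minimal standalone NumPy sketch of that computation (illustration only, not part of this change; the helper name and example numbers are made up):

    import numpy as np

    def coverage_penalty(step_attention_probs, beta):
        # step_attention_probs: [decode_steps, max_time] attention rows for one
        # finished hypothesis; beta plays the role of coverage_penalty_weight.
        accumulated = step_attention_probs.sum(axis=0)  # per source position
        return beta * np.sum(np.log(np.minimum(accumulated, 1.0)))

    # Two decoded steps over three source positions, beta = 0.2:
    probs = np.array([[0.7, 0.2, 0.1],
                      [0.6, 0.3, 0.1]])
    print(coverage_penalty(probs, beta=0.2))  # negative: positions 2 and 3 are
                                              # under-attended, so the score drops

In the diff below, the same accumulation is carried in the new accumulated_attention_probs field of BeamSearchDecoderState, and the weighted penalty is added to the length-normalized score in _get_scores.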
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py |  25
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py               | 175
2 files changed, 183 insertions, 17 deletions
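
For orientation, a hedged usage sketch of the new argument (TF 1.x contrib API as it appears in this tree; the shapes, tokens, and zero-valued encoder outputs are dummy placeholders, and coverage_penalty_weight=0.2 is an arbitrary example value):

    import tensorflow as tf
    from tensorflow.contrib import seq2seq

    batch_size, beam_width, max_time, depth = 4, 3, 7, 32
    vocab_size, emb_dim, num_units = 100, 16, 32
    start_token, end_token = 1, 2

    # Dummy encoder outputs, tiled across the beam as the class docstring describes.
    encoder_outputs = tf.zeros([batch_size, max_time, depth])
    tiled_memory = seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
    tiled_memory_length = seq2seq.tile_batch(
        tf.fill([batch_size], max_time), multiplier=beam_width)

    attention_mechanism = seq2seq.BahdanauAttention(
        num_units=num_units, memory=tiled_memory,
        memory_sequence_length=tiled_memory_length)
    cell = seq2seq.AttentionWrapper(
        tf.nn.rnn_cell.LSTMCell(num_units), attention_mechanism,
        attention_layer_size=num_units)

    decoder = seq2seq.BeamSearchDecoder(
        cell=cell,
        embedding=tf.get_variable("embedding", [vocab_size, emb_dim]),
        start_tokens=tf.fill([batch_size], start_token),
        end_token=end_token,
        initial_state=cell.zero_state(batch_size * beam_width, tf.float32),
        beam_width=beam_width,
        length_penalty_weight=0.6,
        coverage_penalty_weight=0.2)  # the new argument; 0.0 (default) disables it

    outputs, final_state, lengths = seq2seq.dynamic_decode(
        decoder, maximum_iterations=20)

As the diff notes, a non-zero weight raises a ValueError unless the (possibly nested) cell is an AttentionWrapper, since the penalty is computed from accumulated attention probabilities.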
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index f5b6b1bde9..5e28e651c6 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -248,6 +248,7 @@ class TestBeamStep(test.TestCase):
self.vocab_size = 5
self.end_token = 0
self.length_penalty_weight = 0.6
+ self.coverage_penalty_weight = 0.0
def test_step(self):
dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
@@ -258,7 +259,8 @@ class TestBeamStep(test.TestCase):
lengths=constant_op.constant(
2, shape=[self.batch_size, self.beam_width], dtype=dtypes.int64),
finished=array_ops.zeros(
- [self.batch_size, self.beam_width], dtype=dtypes.bool))
+ [self.batch_size, self.beam_width], dtype=dtypes.bool),
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -281,7 +283,8 @@ class TestBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, state_, log_probs_ = sess.run(
@@ -313,7 +316,8 @@ class TestBeamStep(test.TestCase):
lengths=ops.convert_to_tensor(
[[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64),
finished=ops.convert_to_tensor(
- [[False, True, False], [False, False, True]], dtype=dtypes.bool))
+ [[False, True, False], [False, False, True]], dtype=dtypes.bool),
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -336,7 +340,8 @@ class TestBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, state_, log_probs_ = sess.run(
@@ -372,6 +377,7 @@ class TestLargeBeamStep(test.TestCase):
self.vocab_size = 5
self.end_token = 0
self.length_penalty_weight = 0.6
+ self.coverage_penalty_weight = 0.0
def test_step(self):
@@ -411,7 +417,8 @@ class TestLargeBeamStep(test.TestCase):
cell_state=dummy_cell_state,
log_probs=log_probs,
lengths=_lengths,
- finished=_finished)
+ finished=_finished,
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -434,7 +441,8 @@ class TestLargeBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, _, _ = sess.run(
@@ -476,7 +484,9 @@ class BeamSearchDecoderTest(test.TestCase):
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = rnn_cell.LSTMCell(cell_depth)
initial_state = cell.zero_state(batch_size, dtypes.float32)
+ coverage_penalty_weight = 0.0
if has_attention:
+ coverage_penalty_weight = 0.2
inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time, input_depth).astype(
np.float32),
@@ -508,7 +518,8 @@ class BeamSearchDecoderTest(test.TestCase):
initial_state=cell_state,
beam_width=beam_width,
output_layer=output_layer,
- length_penalty_weight=0.0)
+ length_penalty_weight=0.0,
+ coverage_penalty_weight=coverage_penalty_weight)
final_outputs, final_state, final_sequence_lengths = (
decoder.dynamic_decode(
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 74741a7bd6..605e3143fd 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import numpy as np
+from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
@@ -49,7 +50,8 @@ __all__ = [
class BeamSearchDecoderState(
collections.namedtuple("BeamSearchDecoderState",
- ("cell_state", "log_probs", "finished", "lengths"))):
+ ("cell_state", "log_probs", "finished", "lengths",
+ "accumulated_attention_probs"))):
pass
@@ -260,6 +262,10 @@ class BeamSearchDecoder(decoder.Decoder):
decoder_initial_state = decoder_initial_state.clone(
cell_state=tiled_encoder_final_state)
```
+
+ Meanwhile, with `AttentionWrapper`, it is suggested to use a coverage penalty
+ when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages
+ the translation to cover all inputs.
"""
def __init__(self,
@@ -271,6 +277,7 @@ class BeamSearchDecoder(decoder.Decoder):
beam_width,
output_layer=None,
length_penalty_weight=0.0,
+ coverage_penalty_weight=0.0,
reorder_tensor_arrays=True):
"""Initialize the BeamSearchDecoder.
@@ -286,6 +293,8 @@ class BeamSearchDecoder(decoder.Decoder):
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
state will be reordered according to the beam search path. If the
`TensorArray` can be reordered, the stacked form will be returned.
@@ -326,6 +335,7 @@ class BeamSearchDecoder(decoder.Decoder):
self._batch_size = array_ops.size(start_tokens)
self._beam_width = beam_width
self._length_penalty_weight = length_penalty_weight
+ self._coverage_penalty_weight = coverage_penalty_weight
self._initial_cell_state = nest.map_structure(
self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
@@ -411,13 +421,18 @@ class BeamSearchDecoder(decoder.Decoder):
on_value=ops.convert_to_tensor(0.0, dtype=dtype),
off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
dtype=dtype)
+ init_attention_probs = get_attention_probs(
+ self._initial_cell_state, self._coverage_penalty_weight)
+ if init_attention_probs is None:
+ init_attention_probs = ()
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
- [self._batch_size, self._beam_width], dtype=dtypes.int64))
+ [self._batch_size, self._beam_width], dtype=dtypes.int64),
+ accumulated_attention_probs=init_attention_probs)
return (finished, start_inputs, initial_state)
@@ -631,6 +646,7 @@ class BeamSearchDecoder(decoder.Decoder):
beam_width = self._beam_width
end_token = self._end_token
length_penalty_weight = self._length_penalty_weight
+ coverage_penalty_weight = self._coverage_penalty_weight
with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
cell_state = state.cell_state
@@ -655,7 +671,8 @@ class BeamSearchDecoder(decoder.Decoder):
batch_size=batch_size,
beam_width=beam_width,
end_token=end_token,
- length_penalty_weight=length_penalty_weight)
+ length_penalty_weight=length_penalty_weight,
+ coverage_penalty_weight=coverage_penalty_weight)
finished = beam_search_state.finished
sample_ids = beam_search_output.predicted_ids
@@ -667,7 +684,8 @@ class BeamSearchDecoder(decoder.Decoder):
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
- beam_width, end_token, length_penalty_weight):
+ beam_width, end_token, length_penalty_weight,
+ coverage_penalty_weight):
"""Performs a single step of Beam Search Decoding.
Args:
@@ -684,6 +702,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
beam_width: Python int. The size of the beams.
end_token: The int32 end token.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
Returns:
A new beam state.
@@ -693,6 +713,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
# Calculate the current lengths of the predictions
prediction_lengths = beam_state.lengths
previously_finished = beam_state.finished
+ not_finished = math_ops.logical_not(previously_finished)
# Calculate the total log probs for the new hypotheses
# Final Shape: [batch_size, beam_width, vocab_size]
@@ -708,16 +729,29 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
on_value=np.int64(0),
off_value=np.int64(1),
dtype=dtypes.int64)
- add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
+ add_mask = math_ops.to_int64(not_finished)
lengths_to_add *= array_ops.expand_dims(add_mask, 2)
new_prediction_lengths = (
lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
+ # Calculate the accumulated attention probabilities if coverage penalty is
+ # enabled.
+ accumulated_attention_probs = None
+ attention_probs = get_attention_probs(
+ next_cell_state, coverage_penalty_weight)
+ if attention_probs is not None:
+ attention_probs *= array_ops.expand_dims(math_ops.to_float(not_finished), 2)
+ accumulated_attention_probs = (
+ beam_state.accumulated_attention_probs + attention_probs)
+
# Calculate the scores for each beam
scores = _get_scores(
log_probs=total_probs,
sequence_lengths=new_prediction_lengths,
- length_penalty_weight=length_penalty_weight)
+ length_penalty_weight=length_penalty_weight,
+ coverage_penalty_weight=coverage_penalty_weight,
+ finished=previously_finished,
+ accumulated_attention_probs=accumulated_attention_probs)
time = ops.convert_to_tensor(time, name="time")
# During the first time step we only consider the initial beam
@@ -775,6 +809,15 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
range_size=beam_width,
gather_shape=[-1])
next_prediction_len += lengths_to_add
+ next_accumulated_attention_probs = ()
+ if accumulated_attention_probs is not None:
+ next_accumulated_attention_probs = _tensor_gather_helper(
+ gather_indices=next_beam_ids,
+ gather_from=accumulated_attention_probs,
+ batch_size=batch_size,
+ range_size=beam_width,
+ gather_shape=[batch_size * beam_width, -1],
+ name="next_accumulated_attention_probs")
# Pick out the cell_states according to the next_beam_ids. We use a
# different gather_shape here because the cell_state tensors, i.e.
@@ -795,7 +838,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
cell_state=next_cell_state,
log_probs=next_beam_probs,
lengths=next_prediction_len,
- finished=next_finished)
+ finished=next_finished,
+ accumulated_attention_probs=next_accumulated_attention_probs)
output = BeamSearchDecoderOutput(
scores=next_beam_scores,
@@ -805,7 +849,53 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
return output, next_state
-def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
+def get_attention_probs(next_cell_state, coverage_penalty_weight):
+ """Get attention probabilities from the cell state.
+
+ Args:
+ next_cell_state: The next state from the cell, e.g. an instance of
+ AttentionWrapperState if the cell is attentional.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
+
+ Returns:
+ The attention probabilities with shape `[batch_size, beam_width, max_time]`
+ if coverage penalty is enabled. Otherwise, returns None.
+
+ Raises:
+ ValueError: If no cell is attentional but coverage penalty is enabled.
+ """
+ if coverage_penalty_weight == 0.0:
+ return None
+
+ # Attention probabilities of each attention layer. Each with shape
+ # `[batch_size, beam_width, max_time]`.
+ probs_per_attn_layer = []
+ if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
+ probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
+ elif isinstance(next_cell_state, tuple):
+ for state in next_cell_state:
+ if isinstance(state, attention_wrapper.AttentionWrapperState):
+ probs_per_attn_layer.append(attention_probs_from_attn_state(state))
+
+ if not probs_per_attn_layer:
+ raise ValueError(
+ "coverage_penalty_weight must be 0.0 if no cell is attentional.")
+
+ if len(probs_per_attn_layer) == 1:
+ attention_probs = probs_per_attn_layer[0]
+ else:
+ # Calculate the average attention probabilities from all attention layers.
+ attention_probs = [
+ array_ops.expand_dims(prob, -1) for prob in probs_per_attn_layer]
+ attention_probs = array_ops.concat(attention_probs, -1)
+ attention_probs = math_ops.reduce_mean(attention_probs, -1)
+
+ return attention_probs
+
+
+def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
+ coverage_penalty_weight, finished, accumulated_attention_probs):
"""Calculates scores for beam search hypotheses.
Args:
@@ -813,13 +903,78 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
`[batch_size, beam_width, vocab_size]`.
sequence_lengths: The array of sequence lengths.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
+ finished: A boolean tensor of shape `[batch_size, beam_width]` that
+ specifies which elements in the beam are finished already.
+ accumulated_attention_probs: Accumulated attention probabilities up to the
+ current time step, with shape `[batch_size, beam_width, max_time]` if
+ coverage_penalty_weight is not 0.0.
Returns:
- The scores normalized by the length_penalty.
+ The scores normalized by the length_penalty and coverage_penalty.
+
+ Raises:
+ ValueError: If accumulated_attention_probs is None when coverage penalty is
+ enabled.
"""
length_penalty_ = _length_penalty(
sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
- return log_probs / length_penalty_
+ scores = log_probs / length_penalty_
+
+ coverage_penalty_weight = ops.convert_to_tensor(
+ coverage_penalty_weight, name="coverage_penalty_weight")
+ if coverage_penalty_weight.shape.ndims != 0:
+ raise ValueError("coverage_penalty_weight should be a scalar, "
+ "but saw shape: %s" % coverage_penalty_weight.shape)
+
+ if tensor_util.constant_value(coverage_penalty_weight) == 0.0:
+ return scores
+
+ if accumulated_attention_probs is None:
+ raise ValueError(
+ "accumulated_attention_probs can be None only if coverage penalty is "
+ "disabled.")
+
+ # Add source sequence length mask before computing coverage penalty.
+ accumulated_attention_probs = array_ops.where(
+ math_ops.equal(accumulated_attention_probs, 0.0),
+ array_ops.ones_like(accumulated_attention_probs),
+ accumulated_attention_probs)
+
+ # coverage penalty =
+ # sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
+ coverage_penalty = math_ops.reduce_sum(
+ math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2)
+ # Apply coverage penalty to finished predictions.
+ coverage_penalty *= math_ops.to_float(finished)
+ weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
+ # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
+ weighted_coverage_penalty = array_ops.expand_dims(
+ weighted_coverage_penalty, 2)
+ return scores + weighted_coverage_penalty
+
+
+def attention_probs_from_attn_state(attention_state):
+ """Calculates the average attention probabilities.
+
+ Args:
+ attention_state: An instance of `AttentionWrapperState`.
+
+ Returns:
+ The attention probabilities in the given AttentionWrapperState.
+ If there are multiple attention mechanisms, return the average value from
+ all attention mechanisms.
+ """
+ # Attention probabilities over time steps, with shape
+ # `[batch_size, beam_width, max_time]`.
+ attention_probs = attention_state.alignments
+ if isinstance(attention_probs, tuple):
+ attention_probs = [
+ array_ops.expand_dims(prob, -1) for prob in attention_probs]
+ attention_probs = array_ops.concat(attention_probs, -1)
+ attention_probs = math_ops.reduce_mean(attention_probs, -1)
+ return attention_probs
def _length_penalty(sequence_lengths, penalty_factor):