Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 175
1 file changed, 165 insertions(+), 10 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 74741a7bd6..605e3143fd 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import numpy as np
+from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
@@ -49,7 +50,8 @@ __all__ = [
class BeamSearchDecoderState(
collections.namedtuple("BeamSearchDecoderState",
- ("cell_state", "log_probs", "finished", "lengths"))):
+ ("cell_state", "log_probs", "finished", "lengths",
+ "accumulated_attention_probs"))):
pass
@@ -260,6 +262,10 @@ class BeamSearchDecoder(decoder.Decoder):
decoder_initial_state = decoder_initial_state.clone(
cell_state=tiled_encoder_final_state)
```
+
+ Meanwhile, when using `AttentionWrapper`, it is suggested to apply a
+ coverage penalty when computing scores
+ (https://arxiv.org/pdf/1609.08144.pdf); it encourages the translation
+ to cover all inputs (a usage sketch follows this hunk).
"""
def __init__(self,
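
A minimal construction sketch for the suggested usage above; `attention_cell`, `embedding_matrix`, `start_tokens`, `end_token`, `decoder_initial_state`, `beam_width`, and `vocab_size` are placeholders for the usual setup, not names defined in this patch:

```
decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=attention_cell,  # must contain an AttentionWrapper for a
                          # non-zero coverage_penalty_weight
    embedding=embedding_matrix,
    start_tokens=start_tokens,
    end_token=end_token,
    initial_state=decoder_initial_state,
    beam_width=beam_width,
    output_layer=tf.layers.Dense(vocab_size),
    length_penalty_weight=0.6,
    coverage_penalty_weight=0.2)  # 0.0 disables the coverage penalty
```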
@@ -271,6 +277,7 @@ class BeamSearchDecoder(decoder.Decoder):
beam_width,
output_layer=None,
length_penalty_weight=0.0,
+ coverage_penalty_weight=0.0,
reorder_tensor_arrays=True):
"""Initialize the BeamSearchDecoder.
@@ -286,6 +293,8 @@ class BeamSearchDecoder(decoder.Decoder):
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of the
+ source sentence. Disabled with 0.0.
reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
state will be reordered according to the beam search path. If the
`TensorArray` can be reordered, the stacked form will be returned.
@@ -326,6 +335,7 @@ class BeamSearchDecoder(decoder.Decoder):
self._batch_size = array_ops.size(start_tokens)
self._beam_width = beam_width
self._length_penalty_weight = length_penalty_weight
+ self._coverage_penalty_weight = coverage_penalty_weight
self._initial_cell_state = nest.map_structure(
self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
@@ -411,13 +421,18 @@ class BeamSearchDecoder(decoder.Decoder):
on_value=ops.convert_to_tensor(0.0, dtype=dtype),
off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
dtype=dtype)
+ init_attention_probs = get_attention_probs(
+ self._initial_cell_state, self._coverage_penalty_weight)
+ if init_attention_probs is None:
+ init_attention_probs = ()
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
- [self._batch_size, self._beam_width], dtype=dtypes.int64))
+ [self._batch_size, self._beam_width], dtype=dtypes.int64),
+ accumulated_attention_probs=init_attention_probs)
return (finished, start_inputs, initial_state)
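
Note the empty-tuple sentinel above: `nest` treats `()` as a structure with no leaves, so when the coverage penalty is disabled the `accumulated_attention_probs` field is simply skipped by structure-wide maps and gathers. A toy illustration, not from the patch:

```
from tensorflow.python.util import nest

# () contributes no leaves, so map_structure leaves the field untouched.
state = {"log_probs": 1.0, "accumulated_attention_probs": ()}
mapped = nest.map_structure(lambda t: t * 2.0, state)
assert mapped["accumulated_attention_probs"] == ()
```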
@@ -631,6 +646,7 @@ class BeamSearchDecoder(decoder.Decoder):
beam_width = self._beam_width
end_token = self._end_token
length_penalty_weight = self._length_penalty_weight
+ coverage_penalty_weight = self._coverage_penalty_weight
with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
cell_state = state.cell_state
@@ -655,7 +671,8 @@ class BeamSearchDecoder(decoder.Decoder):
batch_size=batch_size,
beam_width=beam_width,
end_token=end_token,
- length_penalty_weight=length_penalty_weight)
+ length_penalty_weight=length_penalty_weight,
+ coverage_penalty_weight=coverage_penalty_weight)
finished = beam_search_state.finished
sample_ids = beam_search_output.predicted_ids
@@ -667,7 +684,8 @@ class BeamSearchDecoder(decoder.Decoder):
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
- beam_width, end_token, length_penalty_weight):
+ beam_width, end_token, length_penalty_weight,
+ coverage_penalty_weight):
"""Performs a single step of Beam Search Decoding.
Args:
@@ -684,6 +702,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
beam_width: Python int. The size of the beams.
end_token: The int32 end token.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of the
+ source sentence. Disabled with 0.0.
Returns:
A new beam state.
@@ -693,6 +713,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
# Calculate the current lengths of the predictions
prediction_lengths = beam_state.lengths
previously_finished = beam_state.finished
+ not_finished = math_ops.logical_not(previously_finished)
# Calculate the total log probs for the new hypotheses
# Final Shape: [batch_size, beam_width, vocab_size]
@@ -708,16 +729,29 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
on_value=np.int64(0),
off_value=np.int64(1),
dtype=dtypes.int64)
- add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
+ add_mask = math_ops.to_int64(not_finished)
lengths_to_add *= array_ops.expand_dims(add_mask, 2)
new_prediction_lengths = (
lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
+ # Calculate the accumulated attention probabilities if coverage penalty is
+ # enabled.
+ accumulated_attention_probs = None
+ attention_probs = get_attention_probs(
+ next_cell_state, coverage_penalty_weight)
+ if attention_probs is not None:
+ attention_probs *= array_ops.expand_dims(
+ math_ops.to_float(not_finished), 2)
+ accumulated_attention_probs = (
+ beam_state.accumulated_attention_probs + attention_probs)
+
# Calculate the scores for each beam
scores = _get_scores(
log_probs=total_probs,
sequence_lengths=new_prediction_lengths,
- length_penalty_weight=length_penalty_weight)
+ length_penalty_weight=length_penalty_weight,
+ coverage_penalty_weight=coverage_penalty_weight,
+ finished=previously_finished,
+ accumulated_attention_probs=accumulated_attention_probs)
time = ops.convert_to_tensor(time, name="time")
# During the first time step we only consider the initial beam
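
A numpy sketch of the accumulation above: a beam's per-step attention mass is zeroed once the beam is finished, so its accumulated probabilities freeze after it emits the end token (toy `[batch, beam, max_time]` shapes, not the real tensors):

```
import numpy as np

acc = np.zeros((1, 2, 3))                    # [batch, beam, max_time]
finished = np.array([[False, False]])
step1 = np.array([[[0.7, 0.2, 0.1],
                   [0.5, 0.3, 0.2]]])
acc += step1 * ~finished[..., None]          # both beams accumulate
finished = np.array([[False, True]])         # beam 1 emits end_token
step2 = np.array([[[0.1, 0.6, 0.3],
                   [0.4, 0.4, 0.2]]])
acc += step2 * ~finished[..., None]          # beam 1 stays frozen
```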
@@ -775,6 +809,15 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
range_size=beam_width,
gather_shape=[-1])
next_prediction_len += lengths_to_add
+ next_accumulated_attention_probs = ()
+ if accumulated_attention_probs is not None:
+ next_accumulated_attention_probs = _tensor_gather_helper(
+ gather_indices=next_beam_ids,
+ gather_from=accumulated_attention_probs,
+ batch_size=batch_size,
+ range_size=beam_width,
+ gather_shape=[batch_size * beam_width, -1],
+ name="next_accumulated_attention_probs")
# Pick out the cell_states according to the next_beam_ids. We use a
# different gather_shape here because the cell_state tensors, i.e.
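
For context, a numpy rendering of the flattened gather that `_tensor_gather_helper` performs with `gather_shape=[batch_size * beam_width, -1]`; toy sizes, not the real helper:

```
import numpy as np

batch_size, beam_width = 2, 2
probs = np.arange(12, dtype=float).reshape(batch_size, beam_width, 3)
next_beam_ids = np.array([[1, 1], [0, 1]])   # chosen parent beams
offsets = np.arange(batch_size)[:, None] * beam_width
flat = probs.reshape(batch_size * beam_width, -1)
gathered = flat[(next_beam_ids + offsets).reshape(-1)]
gathered = gathered.reshape(batch_size, beam_width, -1)
```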
@@ -795,7 +838,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
cell_state=next_cell_state,
log_probs=next_beam_probs,
lengths=next_prediction_len,
- finished=next_finished)
+ finished=next_finished,
+ accumulated_attention_probs=next_accumulated_attention_probs)
output = BeamSearchDecoderOutput(
scores=next_beam_scores,
@@ -805,7 +849,53 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
return output, next_state
-def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
+def get_attention_probs(next_cell_state, coverage_penalty_weight):
+ """Get attention probabilities from the cell state.
+
+ Args:
+ next_cell_state: The next state from the cell, e.g. an instance of
+ AttentionWrapperState if the cell is attentional.
+ coverage_penalty_weight: Float weight to penalize the coverage of the
+ source sentence. Disabled with 0.0.
+
+ Returns:
+ The attention probabilities with shape `[batch_size, beam_width, max_time]`
+ if coverage penalty is enabled. Otherwise, returns None.
+
+ Raises:
+ ValueError: If no cell is attentional but coverage penalty is enabled.
+ """
+ if coverage_penalty_weight == 0.0:
+ return None
+
+ # Attention probabilities of each attention layer. Each with shape
+ # `[batch_size, beam_width, max_time]`.
+ probs_per_attn_layer = []
+ if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
+ probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
+ elif isinstance(next_cell_state, tuple):
+ for state in next_cell_state:
+ if isinstance(state, attention_wrapper.AttentionWrapperState):
+ probs_per_attn_layer.append(attention_probs_from_attn_state(state))
+
+ if not probs_per_attn_layer:
+ raise ValueError(
+ "coverage_penalty_weight must be 0.0 if no cell is attentional.")
+
+ if len(probs_per_attn_layer) == 1:
+ attention_probs = probs_per_attn_layer[0]
+ else:
+ # Calculate the average attention probabilities from all attention layers.
+ attention_probs = [
+ array_ops.expand_dims(prob, -1) for prob in probs_per_attn_layer]
+ attention_probs = array_ops.concat(attention_probs, -1)
+ attention_probs = math_ops.reduce_mean(attention_probs, -1)
+
+ return attention_probs
+
+
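
When more than one `AttentionWrapperState` is found (e.g. a `MultiRNNCell` with several attentional layers), the expand_dims/concat/reduce_mean sequence above averages the per-layer probabilities; a numpy equivalent, with made-up values:

```
import numpy as np

layer1 = np.array([[[0.6, 0.3, 0.1]]])       # [batch, beam, max_time]
layer2 = np.array([[[0.2, 0.5, 0.3]]])
avg = np.mean(np.stack([layer1, layer2], axis=-1), axis=-1)
# avg == [[[0.4, 0.4, 0.2]]]
```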
+def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
+ coverage_penalty_weight, finished, accumulated_attention_probs):
"""Calculates scores for beam search hypotheses.
Args:
@@ -813,13 +903,78 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
`[batch_size, beam_width, vocab_size]`.
sequence_lengths: The array of sequence lengths.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of the
+ source sentence. Disabled with 0.0.
+ finished: A boolean tensor of shape `[batch_size, beam_width]` that
+ specifies which elements in the beam are finished already.
+ accumulated_attention_probs: Accumulated attention probabilities up to the
+ current time step, with shape `[batch_size, beam_width, max_time]` if
+ coverage_penalty_weight is not 0.0.
Returns:
- The scores normalized by the length_penalty.
+ The scores normalized by the length_penalty and coverage_penalty.
+
+ Raises:
+ ValueError: If accumulated_attention_probs is None when coverage penalty
+ is enabled.
"""
length_penalty_ = _length_penalty(
sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
- return log_probs / length_penalty_
+ scores = log_probs / length_penalty_
+
+ coverage_penalty_weight = ops.convert_to_tensor(
+ coverage_penalty_weight, name="coverage_penalty_weight")
+ if coverage_penalty_weight.shape.ndims != 0:
+ raise ValueError("coverage_penalty_weight should be a scalar, "
+ "but saw shape: %s" % coverage_penalty_weight.shape)
+
+ if tensor_util.constant_value(coverage_penalty_weight) == 0.0:
+ return scores
+
+ if accumulated_attention_probs is None:
+ raise ValueError(
+ "accumulated_attention_probs can be None only if coverage penalty is "
+ "disabled.")
+
+ # Add source sequence length mask before computing coverage penalty.
+ accumulated_attention_probs = array_ops.where(
+ math_ops.equal(accumulated_attention_probs, 0.0),
+ array_ops.ones_like(accumulated_attention_probs),
+ accumulated_attention_probs)
+
+ # coverage penalty =
+ # sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
+ coverage_penalty = math_ops.reduce_sum(
+ math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2)
+ # Apply coverage penalty to finished predictions.
+ coverage_penalty *= math_ops.to_float(finished)
+ weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
+ # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
+ weighted_coverage_penalty = array_ops.expand_dims(
+ weighted_coverage_penalty, 2)
+ return scores + weighted_coverage_penalty
+
+
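
A worked example of the penalty term above for a single finished beam over three source positions, with a hypothetical `coverage_penalty_weight` of 0.2: positions whose accumulated attention mass reaches 1.0 contribute log(1) = 0, while under-attended positions pull the score down:

```
import numpy as np

acc = np.array([1.3, 0.5, 0.9])              # accumulated attention probs
coverage_penalty = np.sum(np.log(np.minimum(acc, 1.0)))
# = 0 + log(0.5) + log(0.9) ~= -0.798
score_adjustment = 0.2 * coverage_penalty    # ~= -0.160, added to scores
```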
+def attention_probs_from_attn_state(attention_state):
+ """Calculates the average attention probabilities.
+
+ Args:
+ attention_state: An instance of `AttentionWrapperState`.
+
+ Returns:
+ The attention probabilities in the given AttentionWrapperState.
+ If there are multiple attention mechanisms, return the average value from
+ all attention mechanisms.
+ """
+ # Attention probabilities over time steps, with shape
+ # `[batch_size, beam_width, max_time]`.
+ attention_probs = attention_state.alignments
+ if isinstance(attention_probs, tuple):
+ attention_probs = [
+ array_ops.expand_dims(prob, -1) for prob in attention_probs]
+ attention_probs = array_ops.concat(attention_probs, -1)
+ attention_probs = math_ops.reduce_mean(attention_probs, -1)
+ return attention_probs
def _length_penalty(sequence_lengths, penalty_factor):
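
The body of `_length_penalty` is unchanged by this patch and not shown in the diff; for context, a numpy sketch of the GNMT-style formula it computes, `((5 + |Y|) / 6) ** penalty_factor` (assuming the formula from the paper linked above):

```
import numpy as np

def length_penalty(sequence_lengths, penalty_factor):
  # penalty_factor == 0.0 yields 1.0, i.e. no length normalization.
  lengths = np.asarray(sequence_lengths, dtype=float)
  return ((5.0 + lengths) / 6.0) ** penalty_factor

scores = np.array([-4.2]) / length_penalty([7], penalty_factor=0.6)
```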