diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 175 |
1 files changed, 165 insertions, 10 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 74741a7bd6..605e3143fd 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import numpy as np +from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import dtypes @@ -49,7 +50,8 @@ __all__ = [ class BeamSearchDecoderState( collections.namedtuple("BeamSearchDecoderState", - ("cell_state", "log_probs", "finished", "lengths"))): + ("cell_state", "log_probs", "finished", "lengths", + "accumulated_attention_probs"))): pass @@ -260,6 +262,10 @@ class BeamSearchDecoder(decoder.Decoder): decoder_initial_state = decoder_initial_state.clone( cell_state=tiled_encoder_final_state) ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages + the translation to cover all inputs. """ def __init__(self, @@ -271,6 +277,7 @@ class BeamSearchDecoder(decoder.Decoder): beam_width, output_layer=None, length_penalty_weight=0.0, + coverage_penalty_weight=0.0, reorder_tensor_arrays=True): """Initialize the BeamSearchDecoder. @@ -286,6 +293,8 @@ class BeamSearchDecoder(decoder.Decoder): `tf.layers.Dense`. Optional layer to apply to the RNN output prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell state will be reordered according to the beam search path. If the `TensorArray` can be reordered, the stacked form will be returned. @@ -326,6 +335,7 @@ class BeamSearchDecoder(decoder.Decoder): self._batch_size = array_ops.size(start_tokens) self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight + self._coverage_penalty_weight = coverage_penalty_weight self._initial_cell_state = nest.map_structure( self._maybe_split_batch_beams, initial_state, self._cell.state_size) self._start_tokens = array_ops.tile( @@ -411,13 +421,18 @@ class BeamSearchDecoder(decoder.Decoder): on_value=ops.convert_to_tensor(0.0, dtype=dtype), off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () initial_state = BeamSearchDecoderState( cell_state=self._initial_cell_state, log_probs=log_probs, finished=finished, lengths=array_ops.zeros( - [self._batch_size, self._beam_width], dtype=dtypes.int64)) + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) return (finished, start_inputs, initial_state) @@ -631,6 +646,7 @@ class BeamSearchDecoder(decoder.Decoder): beam_width = self._beam_width end_token = self._end_token length_penalty_weight = self._length_penalty_weight + coverage_penalty_weight = self._coverage_penalty_weight with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): cell_state = state.cell_state @@ -655,7 +671,8 @@ class BeamSearchDecoder(decoder.Decoder): batch_size=batch_size, beam_width=beam_width, end_token=end_token, - length_penalty_weight=length_penalty_weight) + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight) finished = beam_search_state.finished sample_ids = beam_search_output.predicted_ids @@ -667,7 +684,8 @@ class BeamSearchDecoder(decoder.Decoder): def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, - beam_width, end_token, length_penalty_weight): + beam_width, end_token, length_penalty_weight, + coverage_penalty_weight): """Performs a single step of Beam Search Decoding. Args: @@ -684,6 +702,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, beam_width: Python int. The size of the beams. end_token: The int32 end token. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. Returns: A new beam state. @@ -693,6 +713,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, # Calculate the current lengths of the predictions prediction_lengths = beam_state.lengths previously_finished = beam_state.finished + not_finished = math_ops.logical_not(previously_finished) # Calculate the total log probs for the new hypotheses # Final Shape: [batch_size, beam_width, vocab_size] @@ -708,16 +729,29 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, on_value=np.int64(0), off_value=np.int64(1), dtype=dtypes.int64) - add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished)) + add_mask = math_ops.to_int64(not_finished) lengths_to_add *= array_ops.expand_dims(add_mask, 2) new_prediction_lengths = ( lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) + # Calculate the accumulated attention probabilities if coverage penalty is + # enabled. + accumulated_attention_probs = None + attention_probs = get_attention_probs( + next_cell_state, coverage_penalty_weight) + if attention_probs is not None: + attention_probs *= array_ops.expand_dims(math_ops.to_float(not_finished), 2) + accumulated_attention_probs = ( + beam_state.accumulated_attention_probs + attention_probs) + # Calculate the scores for each beam scores = _get_scores( log_probs=total_probs, sequence_lengths=new_prediction_lengths, - length_penalty_weight=length_penalty_weight) + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + finished=previously_finished, + accumulated_attention_probs=accumulated_attention_probs) time = ops.convert_to_tensor(time, name="time") # During the first time step we only consider the initial beam @@ -775,6 +809,15 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, range_size=beam_width, gather_shape=[-1]) next_prediction_len += lengths_to_add + next_accumulated_attention_probs = () + if accumulated_attention_probs is not None: + next_accumulated_attention_probs = _tensor_gather_helper( + gather_indices=next_beam_ids, + gather_from=accumulated_attention_probs, + batch_size=batch_size, + range_size=beam_width, + gather_shape=[batch_size * beam_width, -1], + name="next_accumulated_attention_probs") # Pick out the cell_states according to the next_beam_ids. We use a # different gather_shape here because the cell_state tensors, i.e. @@ -795,7 +838,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, cell_state=next_cell_state, log_probs=next_beam_probs, lengths=next_prediction_len, - finished=next_finished) + finished=next_finished, + accumulated_attention_probs=next_accumulated_attention_probs) output = BeamSearchDecoderOutput( scores=next_beam_scores, @@ -805,7 +849,53 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, return output, next_state -def _get_scores(log_probs, sequence_lengths, length_penalty_weight): +def get_attention_probs(next_cell_state, coverage_penalty_weight): + """Get attention probabilities from the cell state. + + Args: + next_cell_state: The next state from the cell, e.g. an instance of + AttentionWrapperState if the cell is attentional. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. + + Returns: + The attention probabilities with shape `[batch_size, beam_width, max_time]` + if coverage penalty is enabled. Otherwise, returns None. + + Raises: + ValueError: If no cell is attentional but coverage penalty is enabled. + """ + if coverage_penalty_weight == 0.0: + return None + + # Attention probabilities of each attention layer. Each with shape + # `[batch_size, beam_width, max_time]`. + probs_per_attn_layer = [] + if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState): + probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)] + elif isinstance(next_cell_state, tuple): + for state in next_cell_state: + if isinstance(state, attention_wrapper.AttentionWrapperState): + probs_per_attn_layer.append(attention_probs_from_attn_state(state)) + + if not probs_per_attn_layer: + raise ValueError( + "coverage_penalty_weight must be 0.0 if no cell is attentional.") + + if len(probs_per_attn_layer) == 1: + attention_probs = probs_per_attn_layer[0] + else: + # Calculate the average attention probabilities from all attention layers. + attention_probs = [ + array_ops.expand_dims(prob, -1) for prob in probs_per_attn_layer] + attention_probs = array_ops.concat(attention_probs, -1) + attention_probs = math_ops.reduce_mean(attention_probs, -1) + + return attention_probs + + +def _get_scores(log_probs, sequence_lengths, length_penalty_weight, + coverage_penalty_weight, finished, accumulated_attention_probs): """Calculates scores for beam search hypotheses. Args: @@ -813,13 +903,78 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight): `[batch_size, beam_width, vocab_size]`. sequence_lengths: The array of sequence lengths. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. + finished: A boolean tensor of shape `[batch_size, beam_width]` that + specifies which elements in the beam are finished already. + accumulated_attention_probs: Accumulated attention probabilities up to the + current time step, with shape `[batch_size, beam_width, max_time]` if + coverage_penalty_weight is not 0.0. Returns: - The scores normalized by the length_penalty. + The scores normalized by the length_penalty and coverage_penalty. + + Raises: + ValueError: accumulated_attention_probs is None when coverage penalty is + enabled. """ length_penalty_ = _length_penalty( sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) - return log_probs / length_penalty_ + scores = log_probs / length_penalty_ + + coverage_penalty_weight = ops.convert_to_tensor( + coverage_penalty_weight, name="coverage_penalty_weight") + if coverage_penalty_weight.shape.ndims != 0: + raise ValueError("coverage_penalty_weight should be a scalar, " + "but saw shape: %s" % coverage_penalty_weight.shape) + + if tensor_util.constant_value(coverage_penalty_weight) == 0.0: + return scores + + if accumulated_attention_probs is None: + raise ValueError( + "accumulated_attention_probs can be None only if coverage penalty is " + "disabled.") + + # Add source sequence length mask before computing coverage penalty. + accumulated_attention_probs = array_ops.where( + math_ops.equal(accumulated_attention_probs, 0.0), + array_ops.ones_like(accumulated_attention_probs), + accumulated_attention_probs) + + # coverage penalty = + # sum over `max_time` {log(min(accumulated_attention_probs, 1.0))} + coverage_penalty = math_ops.reduce_sum( + math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2) + # Apply coverage penalty to finished predictions. + coverage_penalty *= math_ops.to_float(finished) + weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight + # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1] + weighted_coverage_penalty = array_ops.expand_dims( + weighted_coverage_penalty, 2) + return scores + weighted_coverage_penalty + + +def attention_probs_from_attn_state(attention_state): + """Calculates the average attention probabilities. + + Args: + attention_state: An instance of `AttentionWrapperState`. + + Returns: + The attention probabilities in the given AttentionWrapperState. + If there're multiple attention mechanisms, return the average value from + all attention mechanisms. + """ + # Attention probabilities over time steps, with shape + # `[batch_size, beam_width, max_time]`. + attention_probs = attention_state.alignments + if isinstance(attention_probs, tuple): + attention_probs = [ + array_ops.expand_dims(prob, -1) for prob in attention_probs] + attention_probs = array_ops.concat(attention_probs, -1) + attention_probs = math_ops.reduce_mean(attention_probs, -1) + return attention_probs def _length_penalty(sequence_lengths, penalty_factor): |