diff options
author | 2017-12-15 16:06:31 -0800 | |
---|---|---|
committer | 2017-12-15 16:13:24 -0800 | |
commit | f3df9fcaefeb3ab0fd83f255bec93e1a3c013a5e (patch) | |
tree | bb51b78abdaa81e41b5ee0d6f27ed96658221835 | |
parent | deb50f80e5a87325adec5db826673e05acb4f5ab (diff) |
[tf.contrib.seq2seq] Modify AttentionMechanisms to propagate state.
By default, the state is just the previous alignment.
This allows for more complex attention mechanisms (upcoming).
PiperOrigin-RevId: 179251889
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 32 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | 103 |
2 files changed, 103 insertions, 32 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 01a5540121..e5d591788f 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -254,6 +254,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) expected_final_alignment_history = ResultSummary( shape=(3, 5, 8), dtype=dtype('float32'), mean=0.12500001) @@ -286,6 +288,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) self._testWithAttention( @@ -313,6 +317,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) self._testWithAttention( @@ -342,6 +348,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) self._testWithAttention( @@ -370,6 +378,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) self._testWithAttention( @@ -545,6 +555,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.032228071), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.032228071), alignment_history=()) expected_final_alignment_history = ResultSummary( shape=(3, 5, 8), dtype=dtype('float32'), mean=0.050430927) @@ -578,6 +590,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.028698336), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.028698336), alignment_history=()) expected_final_alignment_history = ResultSummary( shape=(3, 5, 8), dtype=dtype('float32'), mean=0.046009291) @@ -599,7 +613,8 @@ class AttentionWrapperTest(test.TestCase): random_ops.random_normal((b, t, u)), mode='hard') # Just feed previous attention as [1, 0, 0, ...] - attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t)) + attn, unused_state = a( + random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t)) sess.run(variables.global_variables_initializer()) attn_out = attn.eval() # All values should be 0 or 1 @@ -629,6 +644,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.032198936), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.032198936), alignment_history=()) expected_final_alignment_history = ResultSummary( shape=(3, 5, 8), dtype=dtype('float32'), mean=0.050387777) @@ -663,6 +680,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.032198936), + attention_state=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.032198936), alignment_history=()) expected_final_alignment_history = ResultSummary( shape=(3, 5, 8), dtype=dtype('float32'), mean=0.050387777) @@ -697,6 +716,9 @@ class AttentionWrapperTest(test.TestCase): alignments=( ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), + attention_state=( + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), alignment_history=()) expected_final_alignment_history = ( @@ -723,7 +745,8 @@ class AttentionWrapperTest(test.TestCase): random_ops.random_normal((b, t, u)), mode='hard') # Just feed previous attention as [1, 0, 0, ...] - attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t)) + attn, unused_state = a( + random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t)) sess.run(variables.global_variables_initializer()) attn_out = attn.eval() # All values should be 0 or 1 @@ -753,6 +776,9 @@ class AttentionWrapperTest(test.TestCase): alignments=( ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), + attention_state=( + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), alignment_history=()) expected_final_alignment_history = ( ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125), @@ -787,6 +813,8 @@ class AttentionWrapperTest(test.TestCase): time=3, alignments=( ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),), + attention_state=( + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),), alignment_history=()) expected_final_alignment_history = ( diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index e87ef41388..36bfc5685d 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -61,7 +61,14 @@ _zero_state_tensors = rnn_cell_impl._zero_state_tensors # pylint: disable=prote class AttentionMechanism(object): - pass + + @property + def alignments_size(self): + raise NotImplementedError + + @property + def state_size(self): + raise NotImplementedError def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): @@ -161,7 +168,7 @@ class _BaseAttentionMechanism(AttentionMechanism): tensor should be shaped `[batch_size, max_time, ...]`. probability_fn: A `callable`. Converts the score and previous alignments to probabilities. Its signature should be: - `probabilities = probability_fn(score, previous_alignments)`. + `probabilities = probability_fn(score, state)`. memory_sequence_length (optional): Sequence lengths for the batch entries in memory. If provided, the memory tensor rows are masked with zeros for values past the respective sequence lengths. @@ -235,6 +242,10 @@ class _BaseAttentionMechanism(AttentionMechanism): def alignments_size(self): return self._alignments_size + @property + def state_size(self): + return self._alignments_size + def initial_alignments(self, batch_size, dtype): """Creates the initial alignment values for the `AttentionWrapper` class. @@ -254,6 +265,23 @@ class _BaseAttentionMechanism(AttentionMechanism): max_time = self._alignments_size return _zero_state_tensors(max_time, batch_size, dtype) + def initial_state(self, batch_size, dtype): + """Creates the initial state values for the `AttentionWrapper` class. + + This is important for AttentionMechanisms that use the previous alignment + to calculate the alignment at the next time step (e.g. monotonic attention). + + The default behavior is to return the same output as initial_alignments. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A structure of all-zero tensors with shapes as described by `state_size`. + """ + return self.initial_alignments(batch_size, dtype) + def _luong_score(query, keys, scale): """Implements Luong-style (multiplicative) scoring function. @@ -381,13 +409,13 @@ class LuongAttention(_BaseAttentionMechanism): self._scale = scale self._name = name - def __call__(self, query, previous_alignments): + def __call__(self, query, state): """Score the query based on the keys and values. Args: query: Tensor of dtype matching `self.values` and shape `[batch_size, query_depth]`. - previous_alignments: Tensor of dtype matching `self.values` and shape + state: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]` (`alignments_size` is memory's `max_time`). @@ -398,8 +426,9 @@ class LuongAttention(_BaseAttentionMechanism): """ with variable_scope.variable_scope(None, "luong_attention", [query]): score = _luong_score(query, self._keys, self._scale) - alignments = self._probability_fn(score, previous_alignments) - return alignments + alignments = self._probability_fn(score, state) + next_state = alignments + return alignments, next_state def _bahdanau_score(processed_query, keys, normalize): @@ -526,13 +555,13 @@ class BahdanauAttention(_BaseAttentionMechanism): self._normalize = normalize self._name = name - def __call__(self, query, previous_alignments): + def __call__(self, query, state): """Score the query based on the keys and values. Args: query: Tensor of dtype matching `self.values` and shape `[batch_size, query_depth]`. - previous_alignments: Tensor of dtype matching `self.values` and shape + state: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]` (`alignments_size` is memory's `max_time`). @@ -544,8 +573,9 @@ class BahdanauAttention(_BaseAttentionMechanism): with variable_scope.variable_scope(None, "bahdanau_attention", [query]): processed_query = self.query_layer(query) if self.query_layer else query score = _bahdanau_score(processed_query, self._keys, self._normalize) - alignments = self._probability_fn(score, previous_alignments) - return alignments + alignments = self._probability_fn(score, state) + next_state = alignments + return alignments, next_state def safe_cumprod(x, *args, **kwargs): @@ -805,13 +835,13 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): self._name = name self._score_bias_init = score_bias_init - def __call__(self, query, previous_alignments): + def __call__(self, query, state): """Score the query based on the keys and values. Args: query: Tensor of dtype matching `self.values` and shape `[batch_size, query_depth]`. - previous_alignments: Tensor of dtype matching `self.values` and shape + state: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]` (`alignments_size` is memory's `max_time`). @@ -828,8 +858,9 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): "attention_score_bias", dtype=processed_query.dtype, initializer=self._score_bias_init) score += score_bias - alignments = self._probability_fn(score, previous_alignments) - return alignments + alignments = self._probability_fn(score, state) + next_state = alignments + return alignments, next_state class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): @@ -906,13 +937,13 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): self._score_bias_init = score_bias_init self._name = name - def __call__(self, query, previous_alignments): + def __call__(self, query, state): """Score the query based on the keys and values. Args: query: Tensor of dtype matching `self.values` and shape `[batch_size, query_depth]`. - previous_alignments: Tensor of dtype matching `self.values` and shape + state: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]` (`alignments_size` is memory's `max_time`). @@ -928,14 +959,15 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): "attention_score_bias", dtype=query.dtype, initializer=self._score_bias_init) score += score_bias - alignments = self._probability_fn(score, previous_alignments) - return alignments + alignments = self._probability_fn(score, state) + next_state = alignments + return alignments, next_state class AttentionWrapperState( collections.namedtuple("AttentionWrapperState", ("cell_state", "attention", "time", "alignments", - "alignment_history"))): + "alignment_history", "attention_state"))): """`namedtuple` storing the state of a `AttentionWrapper`. Contains: @@ -949,6 +981,9 @@ class AttentionWrapperState( - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s) containing alignment matrices from all time steps for each attention mechanism. Call `stack()` on each to convert to a `Tensor`. + - `attention_state`: A single or tuple of nested objects + containing attention mechanism state for each attention mechanism. + The objects may contain Tensors or TensorArrays. """ def clone(self, **kwargs): @@ -993,11 +1028,11 @@ def hardmax(logits, name=None): math_ops.argmax(logits, -1), depth, dtype=logits.dtype) -def _compute_attention(attention_mechanism, cell_output, previous_alignments, +def _compute_attention(attention_mechanism, cell_output, attention_state, attention_layer): """Computes the attention and alignments for a given attention_mechanism.""" - alignments = attention_mechanism( - cell_output, previous_alignments=previous_alignments) + alignments, next_attention_state = attention_mechanism( + cell_output, state=attention_state) # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] expanded_alignments = array_ops.expand_dims(alignments, 1) @@ -1018,7 +1053,7 @@ def _compute_attention(attention_mechanism, cell_output, previous_alignments, else: attention = context - return attention, alignments + return attention, alignments, next_attention_state class AttentionWrapper(rnn_cell_impl.RNNCell): @@ -1229,6 +1264,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): attention=self._attention_layer_size, alignments=self._item_or_tuple( a.alignments_size for a in self._attention_mechanisms), + attention_state=self._item_or_tuple( + a.state_size for a in self._attention_mechanisms), alignment_history=self._item_or_tuple( () for _ in self._attention_mechanisms)) # sometimes a TensorArray @@ -1278,6 +1315,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): alignments=self._item_or_tuple( attention_mechanism.initial_alignments(batch_size, dtype) for attention_mechanism in self._attention_mechanisms), + attention_state=self._item_or_tuple( + attention_mechanism.initial_state(batch_size, dtype) + for attention_mechanism in self._attention_mechanisms), alignment_history=self._item_or_tuple( tensor_array_ops.TensorArray(dtype=dtype, size=0, dynamic_size=True) @@ -1339,33 +1379,36 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): cell_output, name="checked_cell_output") if self._is_multi: - previous_alignments = state.alignments + previous_attention_state = state.attention_state previous_alignment_history = state.alignment_history else: - previous_alignments = [state.alignments] + previous_attention_state = [state.attention_state] previous_alignment_history = [state.alignment_history] all_alignments = [] all_attentions = [] - all_histories = [] + all_attention_states = [] + maybe_all_histories = [] for i, attention_mechanism in enumerate(self._attention_mechanisms): - attention, alignments = _compute_attention( - attention_mechanism, cell_output, previous_alignments[i], + attention, alignments, next_attention_state = _compute_attention( + attention_mechanism, cell_output, previous_attention_state[i], self._attention_layers[i] if self._attention_layers else None) alignment_history = previous_alignment_history[i].write( state.time, alignments) if self._alignment_history else () + all_attention_states.append(next_attention_state) all_alignments.append(alignments) - all_histories.append(alignment_history) all_attentions.append(attention) + maybe_all_histories.append(alignment_history) attention = array_ops.concat(all_attentions, 1) next_state = AttentionWrapperState( time=state.time + 1, cell_state=next_cell_state, attention=attention, + attention_state=self._item_or_tuple(all_attention_states), alignments=self._item_or_tuple(all_alignments), - alignment_history=self._item_or_tuple(all_histories)) + alignment_history=self._item_or_tuple(maybe_all_histories)) if self._output_attention: return attention, next_state |