about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-12-15 16:06:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 16:13:24 -0800
commitf3df9fcaefeb3ab0fd83f255bec93e1a3c013a5e (patch)
treebb51b78abdaa81e41b5ee0d6f27ed96658221835
parentdeb50f80e5a87325adec5db826673e05acb4f5ab (diff)
[tf.contrib.seq2seq] Modify AttentionMechanisms to propagate state.
By default, the state is just the previous alignment. This allows for more complex attention mechanisms (upcoming). PiperOrigin-RevId: 179251889
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py32
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py103
2 files changed, 103 insertions, 32 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 01a5540121..e5d591788f 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -254,6 +254,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
expected_final_alignment_history = ResultSummary(
shape=(3, 5, 8), dtype=dtype('float32'), mean=0.12500001)
@@ -286,6 +288,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -313,6 +317,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -342,6 +348,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -370,6 +378,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -545,6 +555,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.032228071),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.032228071),
alignment_history=())
expected_final_alignment_history = ResultSummary(
shape=(3, 5, 8), dtype=dtype('float32'), mean=0.050430927)
@@ -578,6 +590,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.028698336),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.028698336),
alignment_history=())
expected_final_alignment_history = ResultSummary(
shape=(3, 5, 8), dtype=dtype('float32'), mean=0.046009291)
@@ -599,7 +613,8 @@ class AttentionWrapperTest(test.TestCase):
random_ops.random_normal((b, t, u)),
mode='hard')
# Just feed previous attention as [1, 0, 0, ...]
- attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+ attn, unused_state = a(
+ random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
sess.run(variables.global_variables_initializer())
attn_out = attn.eval()
# All values should be 0 or 1
@@ -629,6 +644,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.032198936),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.032198936),
alignment_history=())
expected_final_alignment_history = ResultSummary(
shape=(3, 5, 8), dtype=dtype('float32'), mean=0.050387777)
@@ -663,6 +680,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.032198936),
+ attention_state=ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.032198936),
alignment_history=())
expected_final_alignment_history = ResultSummary(
shape=(3, 5, 8), dtype=dtype('float32'), mean=0.050387777)
@@ -697,6 +716,9 @@ class AttentionWrapperTest(test.TestCase):
alignments=(
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ attention_state=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
alignment_history=())
expected_final_alignment_history = (
@@ -723,7 +745,8 @@ class AttentionWrapperTest(test.TestCase):
random_ops.random_normal((b, t, u)),
mode='hard')
# Just feed previous attention as [1, 0, 0, ...]
- attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+ attn, unused_state = a(
+ random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
sess.run(variables.global_variables_initializer())
attn_out = attn.eval()
# All values should be 0 or 1
@@ -753,6 +776,9 @@ class AttentionWrapperTest(test.TestCase):
alignments=(
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ attention_state=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
alignment_history=())
expected_final_alignment_history = (
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
@@ -787,6 +813,8 @@ class AttentionWrapperTest(test.TestCase):
time=3,
alignments=(
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),),
+ attention_state=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),),
alignment_history=())
expected_final_alignment_history = (
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index e87ef41388..36bfc5685d 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -61,7 +61,14 @@ _zero_state_tensors = rnn_cell_impl._zero_state_tensors # pylint: disable=prote
class AttentionMechanism(object):
- pass
+
+ @property
+ def alignments_size(self):
+ raise NotImplementedError
+
+ @property
+ def state_size(self):
+ raise NotImplementedError
def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
@@ -161,7 +168,7 @@ class _BaseAttentionMechanism(AttentionMechanism):
tensor should be shaped `[batch_size, max_time, ...]`.
probability_fn: A `callable`. Converts the score and previous alignments
to probabilities. Its signature should be:
- `probabilities = probability_fn(score, previous_alignments)`.
+ `probabilities = probability_fn(score, state)`.
memory_sequence_length (optional): Sequence lengths for the batch entries
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
@@ -235,6 +242,10 @@ class _BaseAttentionMechanism(AttentionMechanism):
def alignments_size(self):
return self._alignments_size
+ @property
+ def state_size(self):
+ return self._alignments_size
+
def initial_alignments(self, batch_size, dtype):
"""Creates the initial alignment values for the `AttentionWrapper` class.
@@ -254,6 +265,23 @@ class _BaseAttentionMechanism(AttentionMechanism):
max_time = self._alignments_size
return _zero_state_tensors(max_time, batch_size, dtype)
+ def initial_state(self, batch_size, dtype):
+ """Creates the initial state values for the `AttentionWrapper` class.
+
+ This is important for AttentionMechanisms that use the previous alignment
+ to calculate the alignment at the next time step (e.g. monotonic attention).
+
+ The default behavior is to return the same output as initial_alignments.
+
+ Args:
+ batch_size: `int32` scalar, the batch_size.
+ dtype: The `dtype`.
+
+ Returns:
+ A structure of all-zero tensors with shapes as described by `state_size`.
+ """
+ return self.initial_alignments(batch_size, dtype)
+
def _luong_score(query, keys, scale):
"""Implements Luong-style (multiplicative) scoring function.
@@ -381,13 +409,13 @@ class LuongAttention(_BaseAttentionMechanism):
self._scale = scale
self._name = name
- def __call__(self, query, previous_alignments):
+ def __call__(self, query, state):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
- previous_alignments: Tensor of dtype matching `self.values` and shape
+ state: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
@@ -398,8 +426,9 @@ class LuongAttention(_BaseAttentionMechanism):
"""
with variable_scope.variable_scope(None, "luong_attention", [query]):
score = _luong_score(query, self._keys, self._scale)
- alignments = self._probability_fn(score, previous_alignments)
- return alignments
+ alignments = self._probability_fn(score, state)
+ next_state = alignments
+ return alignments, next_state
def _bahdanau_score(processed_query, keys, normalize):
@@ -526,13 +555,13 @@ class BahdanauAttention(_BaseAttentionMechanism):
self._normalize = normalize
self._name = name
- def __call__(self, query, previous_alignments):
+ def __call__(self, query, state):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
- previous_alignments: Tensor of dtype matching `self.values` and shape
+ state: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
@@ -544,8 +573,9 @@ class BahdanauAttention(_BaseAttentionMechanism):
with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
processed_query = self.query_layer(query) if self.query_layer else query
score = _bahdanau_score(processed_query, self._keys, self._normalize)
- alignments = self._probability_fn(score, previous_alignments)
- return alignments
+ alignments = self._probability_fn(score, state)
+ next_state = alignments
+ return alignments, next_state
def safe_cumprod(x, *args, **kwargs):
@@ -805,13 +835,13 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
self._name = name
self._score_bias_init = score_bias_init
- def __call__(self, query, previous_alignments):
+ def __call__(self, query, state):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
- previous_alignments: Tensor of dtype matching `self.values` and shape
+ state: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
@@ -828,8 +858,9 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
"attention_score_bias", dtype=processed_query.dtype,
initializer=self._score_bias_init)
score += score_bias
- alignments = self._probability_fn(score, previous_alignments)
- return alignments
+ alignments = self._probability_fn(score, state)
+ next_state = alignments
+ return alignments, next_state
class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
@@ -906,13 +937,13 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
self._score_bias_init = score_bias_init
self._name = name
- def __call__(self, query, previous_alignments):
+ def __call__(self, query, state):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
- previous_alignments: Tensor of dtype matching `self.values` and shape
+ state: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
@@ -928,14 +959,15 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
"attention_score_bias", dtype=query.dtype,
initializer=self._score_bias_init)
score += score_bias
- alignments = self._probability_fn(score, previous_alignments)
- return alignments
+ alignments = self._probability_fn(score, state)
+ next_state = alignments
+ return alignments, next_state
class AttentionWrapperState(
collections.namedtuple("AttentionWrapperState",
("cell_state", "attention", "time", "alignments",
- "alignment_history"))):
+ "alignment_history", "attention_state"))):
"""`namedtuple` storing the state of a `AttentionWrapper`.
Contains:
@@ -949,6 +981,9 @@ class AttentionWrapperState(
- `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s)
containing alignment matrices from all time steps for each attention
mechanism. Call `stack()` on each to convert to a `Tensor`.
+ - `attention_state`: A single or tuple of nested objects
+ containing attention mechanism state for each attention mechanism.
+ The objects may contain Tensors or TensorArrays.
"""
def clone(self, **kwargs):
@@ -993,11 +1028,11 @@ def hardmax(logits, name=None):
math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
-def _compute_attention(attention_mechanism, cell_output, previous_alignments,
+def _compute_attention(attention_mechanism, cell_output, attention_state,
attention_layer):
"""Computes the attention and alignments for a given attention_mechanism."""
- alignments = attention_mechanism(
- cell_output, previous_alignments=previous_alignments)
+ alignments, next_attention_state = attention_mechanism(
+ cell_output, state=attention_state)
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
@@ -1018,7 +1053,7 @@ def _compute_attention(attention_mechanism, cell_output, previous_alignments,
else:
attention = context
- return attention, alignments
+ return attention, alignments, next_attention_state
class AttentionWrapper(rnn_cell_impl.RNNCell):
@@ -1229,6 +1264,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
attention=self._attention_layer_size,
alignments=self._item_or_tuple(
a.alignments_size for a in self._attention_mechanisms),
+ attention_state=self._item_or_tuple(
+ a.state_size for a in self._attention_mechanisms),
alignment_history=self._item_or_tuple(
() for _ in self._attention_mechanisms)) # sometimes a TensorArray
@@ -1278,6 +1315,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
alignments=self._item_or_tuple(
attention_mechanism.initial_alignments(batch_size, dtype)
for attention_mechanism in self._attention_mechanisms),
+ attention_state=self._item_or_tuple(
+ attention_mechanism.initial_state(batch_size, dtype)
+ for attention_mechanism in self._attention_mechanisms),
alignment_history=self._item_or_tuple(
tensor_array_ops.TensorArray(dtype=dtype, size=0,
dynamic_size=True)
@@ -1339,33 +1379,36 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
cell_output, name="checked_cell_output")
if self._is_multi:
- previous_alignments = state.alignments
+ previous_attention_state = state.attention_state
previous_alignment_history = state.alignment_history
else:
- previous_alignments = [state.alignments]
+ previous_attention_state = [state.attention_state]
previous_alignment_history = [state.alignment_history]
all_alignments = []
all_attentions = []
- all_histories = []
+ all_attention_states = []
+ maybe_all_histories = []
for i, attention_mechanism in enumerate(self._attention_mechanisms):
- attention, alignments = _compute_attention(
- attention_mechanism, cell_output, previous_alignments[i],
+ attention, alignments, next_attention_state = _compute_attention(
+ attention_mechanism, cell_output, previous_attention_state[i],
self._attention_layers[i] if self._attention_layers else None)
alignment_history = previous_alignment_history[i].write(
state.time, alignments) if self._alignment_history else ()
+ all_attention_states.append(next_attention_state)
all_alignments.append(alignments)
- all_histories.append(alignment_history)
all_attentions.append(attention)
+ maybe_all_histories.append(alignment_history)
attention = array_ops.concat(all_attentions, 1)
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
+ attention_state=self._item_or_tuple(all_attention_states),
alignments=self._item_or_tuple(all_alignments),
- alignment_history=self._item_or_tuple(all_histories))
+ alignment_history=self._item_or_tuple(maybe_all_histories))
if self._output_attention:
return attention, next_state