author     Adam Roberts <adarob@google.com>               2017-07-19 17:08:07 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-07-19 17:12:05 -0700
commit  965ae3e34e5e5dd33a4910fdf3a31d0f7a250ac1 (patch)
tree    4803b646dffa9d98e1d58a5ccd59ff420ece9762
parent  1ee98618dd74404e26ee0202f27e05cb3dcef5c3 (diff)
Add multi-head attention capabilities to AttentionWrapper via the specification of multiple AttentionMechanisms.
PiperOrigin-RevId: 162557562
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  172
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py                216
2 files changed, 307 insertions, 81 deletions
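
For reference, a minimal usage sketch of the API after this change (not part of the commit; the tensor names and dimensions below are illustrative assumptions that roughly mirror the test configuration). AttentionWrapper now also accepts a list of AttentionMechanism instances, together with one attention_layer_size per mechanism, and the per-mechanism attention outputs are concatenated:

    import tensorflow as tf

    # Illustrative encoder outputs: [batch, max_time, depth] (names/sizes assumed).
    encoder_outputs = tf.placeholder(tf.float32, shape=[None, None, 10])
    source_lengths = tf.placeholder(tf.int32, shape=[None])

    attention_mechanisms = [
        tf.contrib.seq2seq.BahdanauAttention(
            num_units=9, memory=encoder_outputs,
            memory_sequence_length=source_lengths),
        tf.contrib.seq2seq.LuongAttention(
            num_units=9, memory=encoder_outputs,
            memory_sequence_length=source_lengths),
    ]

    cell = tf.contrib.seq2seq.AttentionWrapper(
        tf.contrib.rnn.LSTMCell(9),
        attention_mechanisms,           # a list/tuple enables multi-attention
        attention_layer_size=[3, 4],    # one size per mechanism; outputs are concatenated
        alignment_history=True)
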
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 8401918128..91493302b1 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -88,6 +88,30 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=None,
attention_layer_size=6,
name=''):
+ self._testWithMaybeMultiAttention(
+ is_multi=False,
+ create_attention_mechanisms=[create_attention_mechanism],
+ expected_final_output=expected_final_output,
+ expected_final_state=expected_final_state,
+ attention_mechanism_depths=[attention_mechanism_depth],
+ alignment_history=alignment_history,
+ expected_final_alignment_history=expected_final_alignment_history,
+ attention_layer_sizes=[attention_layer_size],
+ name=name)
+
+ def _testWithMaybeMultiAttention(self,
+ is_multi,
+ create_attention_mechanisms,
+ expected_final_output,
+ expected_final_state,
+ attention_mechanism_depths,
+ alignment_history=False,
+ expected_final_alignment_history=None,
+ attention_layer_sizes=None,
+ name=''):
+ # Allow is_multi to be True with a single mechanism to enable test for
+ # passing in a single mechanism in a list.
+ assert len(create_attention_mechanisms) == 1 or is_multi
encoder_sequence_length = [3, 2, 3, 1, 1]
decoder_sequence_length = [2, 0, 1, 2, 3]
batch_size = 5
@@ -97,10 +121,12 @@ class AttentionWrapperTest(test.TestCase):
encoder_output_depth = 10
cell_depth = 9
- if attention_layer_size is not None:
- attention_depth = attention_layer_size
+ if attention_layer_sizes is None:
+ attention_depth = encoder_output_depth * len(create_attention_mechanisms)
else:
- attention_depth = encoder_output_depth
+ # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
+ attention_depth = sum([attention_layer_size or encoder_output_depth
+ for attention_layer_size in attention_layer_sizes])
decoder_inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time,
@@ -111,10 +137,12 @@ class AttentionWrapperTest(test.TestCase):
encoder_output_depth).astype(np.float32),
shape=(None, None, encoder_output_depth))
- attention_mechanism = create_attention_mechanism(
- num_units=attention_mechanism_depth,
- memory=encoder_outputs,
- memory_sequence_length=encoder_sequence_length)
+ attention_mechanisms = [
+ creator(num_units=depth,
+ memory=encoder_outputs,
+ memory_sequence_length=encoder_sequence_length)
+ for creator, depth in zip(create_attention_mechanisms,
+ attention_mechanism_depths)]
with self.test_session(use_gpu=True) as sess:
with vs.variable_scope(
@@ -123,8 +151,9 @@ class AttentionWrapperTest(test.TestCase):
cell = rnn_cell.LSTMCell(cell_depth)
cell = wrapper.AttentionWrapper(
cell,
- attention_mechanism,
- attention_layer_size=attention_layer_size,
+ attention_mechanisms if is_multi else attention_mechanisms[0],
+ attention_layer_size=(attention_layer_sizes if is_multi
+ else attention_layer_sizes[0]),
alignment_history=alignment_history)
helper = helper_py.TrainingHelper(decoder_inputs,
decoder_sequence_length)
@@ -156,12 +185,23 @@ class AttentionWrapperTest(test.TestCase):
tuple(final_state.cell_state.h.get_shape().as_list()))
if alignment_history:
- state_alignment_history = final_state.alignment_history.stack()
+ if is_multi:
+ state_alignment_history = []
+ for history_array in final_state.alignment_history:
+ history = history_array.stack()
+ self.assertEqual(
+ (None, batch_size, None),
+ tuple(history.get_shape().as_list()))
+ state_alignment_history.append(history)
+ state_alignment_history = tuple(state_alignment_history)
+ else:
+ state_alignment_history = final_state.alignment_history.stack()
+ self.assertEqual(
+ (None, batch_size, None),
+ tuple(state_alignment_history.get_shape().as_list()))
# Remove the history from final_state for purposes of the
# remainder of the tests.
final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access
- self.assertEqual((None, batch_size, None),
- tuple(state_alignment_history.get_shape().as_list()))
else:
state_alignment_history = ()
@@ -617,6 +657,114 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=expected_final_alignment_history,
name='testLuongMonotonicScaled')
+ def testMultiAttention(self):
+ create_attention_mechanisms = (
+ wrapper.BahdanauAttention, wrapper.LuongAttention)
+
+ expected_final_output = BasicDecoderOutput(
+ rnn_output=ResultSummary(
+ shape=(5, 3, 7), dtype=dtype('float32'), mean=0.0011709079),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=3.2000000000000002))
+ expected_final_state = AttentionWrapperState(
+ cell_state=LSTMStateTuple(
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0038725811),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0019329828)),
+ attention=ResultSummary(
+ shape=(5, 7), dtype=dtype('float32'), mean=0.001174294),
+ time=3,
+ alignments=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignment_history=())
+
+ expected_final_alignment_history = (
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
+
+ self._testWithMaybeMultiAttention(
+ True,
+ create_attention_mechanisms,
+ expected_final_output,
+ expected_final_state,
+ attention_mechanism_depths=[9, 9],
+ attention_layer_sizes=[3, 4],
+ alignment_history=True,
+ expected_final_alignment_history=expected_final_alignment_history,
+ name='testMultiAttention')
+
+ def testMultiAttentionNoAttentionLayer(self):
+ create_attention_mechanisms = (
+ wrapper.BahdanauAttention, wrapper.LuongAttention)
+
+ expected_final_output = BasicDecoderOutput(
+ rnn_output=ResultSummary(
+ shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11691988),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=7.2666666666666666))
+ expected_final_state = AttentionWrapperState(
+ cell_state=LSTMStateTuple(
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0036486709),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)),
+ attention=ResultSummary(
+ shape=(5, 20), dtype=dtype('float32'), mean=0.11680689),
+ time=3,
+ alignments=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignment_history=())
+ expected_final_alignment_history = (
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
+
+ self._testWithMaybeMultiAttention(
+ is_multi=True,
+ create_attention_mechanisms=create_attention_mechanisms,
+ expected_final_output=expected_final_output,
+ expected_final_state=expected_final_state,
+ attention_mechanism_depths=[9, 9],
+ alignment_history=True,
+ expected_final_alignment_history=expected_final_alignment_history,
+ name='testMultiAttention')
+
+ def testSingleAttentionAsList(self):
+ create_attention_mechanisms = [wrapper.BahdanauAttention]
+
+ expected_final_output = BasicDecoderOutput(
+ rnn_output=ResultSummary(
+ shape=(5, 3, 3), dtype=dtype('float32'), mean=-0.0098485695),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=1.8))
+ expected_final_state = AttentionWrapperState(
+ cell_state=LSTMStateTuple(
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0040023471),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0019979973)),
+ attention=ResultSummary(
+ shape=(5, 3), dtype=dtype('float32'), mean=-0.0098808752),
+ time=3,
+ alignments=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),),
+ alignment_history=())
+
+ expected_final_alignment_history = (
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),)
+
+ self._testWithMaybeMultiAttention(
+ is_multi=True, # pass the AttentionMechanism wrapped in a list
+ create_attention_mechanisms=create_attention_mechanisms,
+ expected_final_output=expected_final_output,
+ expected_final_state=expected_final_state,
+ attention_mechanism_depths=[9],
+ attention_layer_sizes=[3],
+ alignment_history=True,
+ expected_final_alignment_history=expected_final_alignment_history,
+ name='testMultiAttention')
if __name__ == '__main__':
test.main()
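
A short continuation of the sketch above (still illustrative, not part of the commit): with a list of mechanisms, the AttentionWrapperState produced by the hunks below carries one alignments Tensor and, if enabled, one alignment_history TensorArray per mechanism, exposed as tuples.

    # Continuing the illustrative setup above.
    state = cell.zero_state(batch_size=5, dtype=tf.float32)
    # With two mechanisms, alignments and alignment_history are 2-tuples.
    assert isinstance(state.alignments, tuple) and len(state.alignments) == 2
    assert isinstance(state.alignment_history, tuple)
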
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 9c6939bb46..a162a919cf 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -917,9 +917,11 @@ class AttentionWrapperState(
step.
- `attention`: The attention emitted at the previous time step.
- `time`: int32 scalar containing the current time step.
- - `alignments`: The alignment emitted at the previous time step.
- - `alignment_history`: (if enabled) a `TensorArray` containing alignment
- matrices from all time steps. Call `stack()` to convert to a `Tensor`.
+ - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
+ emitted at the previous time step for each attention mechanism.
+ - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s)
+ containing alignment matrices from all time steps for each attention
+ mechanism. Call `stack()` on each to convert to a `Tensor`.
"""
def clone(self, **kwargs):
@@ -964,6 +966,34 @@ def hardmax(logits, name=None):
math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
+def _compute_attention(attention_mechanism, cell_output, previous_alignments,
+ attention_layer):
+ """Computes the attention and alignments for a given attention_mechanism."""
+ alignments = attention_mechanism(
+ cell_output, previous_alignments=previous_alignments)
+
+ # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
+ expanded_alignments = array_ops.expand_dims(alignments, 1)
+ # Context is the inner product of alignments and values along the
+ # memory time dimension.
+ # alignments shape is
+ # [batch_size, 1, memory_time]
+ # attention_mechanism.values shape is
+ # [batch_size, memory_time, attention_mechanism.num_units]
+ # the batched matmul is over memory_time, so the output shape is
+ # [batch_size, 1, attention_mechanism.num_units].
+ # we then squeeze out the singleton dim.
+ context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
+ context = array_ops.squeeze(context, [1])
+
+ if attention_layer is not None:
+ attention = attention_layer(array_ops.concat([cell_output, context], 1))
+ else:
+ attention = context
+
+ return attention, alignments
+
+
class AttentionWrapper(rnn_cell_impl.RNNCell):
"""Wraps another `RNNCell` with attention.
"""
@@ -981,11 +1011,14 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
Args:
cell: An instance of `RNNCell`.
- attention_mechanism: An instance of `AttentionMechanism`.
- attention_layer_size: Python integer, the depth of the attention (output)
- layer. If None (default), use the context as attention at each time
- step. Otherwise, feed the context and cell output into the attention
- layer to generate attention at each time step.
+ attention_mechanism: A list of `AttentionMechanism` instances or a single
+ instance.
+ attention_layer_size: A list of Python integers or a single Python
+ integer, the depth of the attention (output) layer(s). If None
+ (default), use the context as attention at each time step. Otherwise,
+ feed the context and cell output into the attention layer to generate
+ attention at each time step. If attention_mechanism is a list,
+ attention_layer_size must be a list of the same length.
alignment_history: Python boolean, whether to store alignment history
from all time steps in the final output state (currently stored as a
time major `TensorArray` on which you must call `stack()`).
@@ -1005,15 +1038,35 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
does not match the batch size of `initial_cell_state`, proper
behavior is not guaranteed.
name: Name to use when creating ops.
+
+ Raises:
+ TypeError: `attention_layer_size` is not None and (`attention_mechanism`
+ is a list but `attention_layer_size` is not; or vice versa).
+ ValueError: if `attention_layer_size` is not None, `attention_mechanism`
+ is a list, and its length does not match that of `attention_layer_size`.
"""
super(AttentionWrapper, self).__init__(name=name)
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError(
"cell must be an RNNCell, saw type: %s" % type(cell).__name__)
- if not isinstance(attention_mechanism, AttentionMechanism):
- raise TypeError(
- "attention_mechanism must be a AttentionMechanism, saw type: %s"
- % type(attention_mechanism).__name__)
+ if isinstance(attention_mechanism, (list, tuple)):
+ self._is_multi = True
+ attention_mechanisms = attention_mechanism
+ for attention_mechanism in attention_mechanisms:
+ if not isinstance(attention_mechanism, AttentionMechanism):
+ raise TypeError(
+ "attention_mechanism must contain only instances of "
+ "AttentionMechanism, saw type: %s"
+ % type(attention_mechanism).__name__)
+ else:
+ self._is_multi = False
+ if not isinstance(attention_mechanism, AttentionMechanism):
+ raise TypeError(
+ "attention_mechanism must be an AttentionMechanism or list of "
+ "multiple AttentionMechanism instances, saw type: %s"
+ % type(attention_mechanism).__name__)
+ attention_mechanisms = (attention_mechanism,)
+
if cell_input_fn is None:
cell_input_fn = (
lambda inputs, attention: array_ops.concat([inputs, attention], -1))
@@ -1024,16 +1077,28 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
% type(cell_input_fn).__name__)
if attention_layer_size is not None:
- self._attention_layer = layers_core.Dense(
- attention_layer_size, name="attention_layer", use_bias=False)
- self._attention_layer_size = attention_layer_size
+ attention_layer_sizes = tuple(
+ attention_layer_size
+ if isinstance(attention_layer_size, (list, tuple))
+ else (attention_layer_size,))
+ if len(attention_layer_sizes) != len(attention_mechanisms):
+ raise ValueError(
+ "If provided, attention_layer_size must contain exactly one "
+ "integer per attention_mechanism, saw: %d vs %d"
+ % (len(attention_layer_sizes), len(attention_mechanisms)))
+ self._attention_layers = tuple(
+ layers_core.Dense(
+ attention_layer_size, name="attention_layer", use_bias=False)
+ for attention_layer_size in attention_layer_sizes)
+ self._attention_layer_size = sum(attention_layer_sizes)
else:
- self._attention_layer = None
- self._attention_layer_size = attention_mechanism.values.get_shape()[
- -1].value
+ self._attention_layers = None
+ self._attention_layer_size = sum(
+ attention_mechanism.values.get_shape()[-1].value
+ for attention_mechanism in attention_mechanisms)
self._cell = cell
- self._attention_mechanism = attention_mechanism
+ self._attention_mechanisms = attention_mechanisms
self._cell_input_fn = cell_input_fn
self._output_attention = output_attention
self._alignment_history = alignment_history
@@ -1053,13 +1118,36 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
"via the tf.contrib.seq2seq.tile_batch function with argument "
"multiple=beam_width.")
with ops.control_dependencies(
- [check_ops.assert_equal(state_batch_size,
- self._attention_mechanism.batch_size,
- message=error_message)]):
+ self._batch_size_checks(state_batch_size, error_message)):
self._initial_cell_state = nest.map_structure(
lambda s: array_ops.identity(s, name="check_initial_cell_state"),
initial_cell_state)
+ def _batch_size_checks(self, batch_size, error_message):
+ return [check_ops.assert_equal(batch_size,
+ attention_mechanism.batch_size,
+ message=error_message)
+ for attention_mechanism in self._attention_mechanisms]
+
+ def _item_or_tuple(self, seq):
+ """Returns `seq` as tuple or the singular element.
+
+ Which is returned depends on how the AttentionMechanism(s) were passed
+ to the constructor.
+
+ Args:
+ seq: A non-empty sequence of items or generator.
+
+ Returns:
+ Either the values in the sequence as a tuple if AttentionMechanism(s)
+ were passed to the constructor as a sequence or the singular element.
+ """
+ t = tuple(seq)
+ if self._is_multi:
+ return t
+ else:
+ return t[0]
+
@property
def output_size(self):
if self._output_attention:
@@ -1073,8 +1161,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
cell_state=self._cell.state_size,
time=tensor_shape.TensorShape([]),
attention=self._attention_layer_size,
- alignments=self._attention_mechanism.alignments_size,
- alignment_history=()) # alignment_history is sometimes a TensorArray
+ alignments=self._item_or_tuple(
+ a.alignments_size for a in self._attention_mechanisms),
+ alignment_history=self._item_or_tuple(
+ () for _ in self._attention_mechanisms)) # sometimes a TensorArray
def zero_state(self, batch_size, dtype):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
@@ -1091,25 +1181,23 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
"the batch_size= argument passed to zero_state is "
"batch_size * beam_width.")
with ops.control_dependencies(
- [check_ops.assert_equal(batch_size,
- self._attention_mechanism.batch_size,
- message=error_message)]):
+ self._batch_size_checks(batch_size, error_message)):
cell_state = nest.map_structure(
lambda s: array_ops.identity(s, name="checked_cell_state"),
cell_state)
- if self._alignment_history:
- alignment_history = tensor_array_ops.TensorArray(
- dtype=dtype, size=0, dynamic_size=True)
- else:
- alignment_history = ()
return AttentionWrapperState(
cell_state=cell_state,
time=array_ops.zeros([], dtype=dtypes.int32),
attention=_zero_state_tensors(self._attention_layer_size, batch_size,
dtype),
- alignments=self._attention_mechanism.initial_alignments(
- batch_size, dtype),
- alignment_history=alignment_history)
+ alignments=self._item_or_tuple(
+ attention_mechanism.initial_alignments(batch_size, dtype)
+ for attention_mechanism in self._attention_mechanisms),
+ alignment_history=self._item_or_tuple(
+ tensor_array_ops.TensorArray(dtype=dtype, size=0,
+ dynamic_size=True)
+ if self._alignment_history else ()
+ for _ in self._attention_mechanisms))
def call(self, inputs, state):
"""Perform a step of attention-wrapped RNN.
@@ -1161,48 +1249,38 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
"the tf.contrib.seq2seq.tile_batch function with argument "
"multiple=beam_width.")
with ops.control_dependencies(
- [check_ops.assert_equal(cell_batch_size,
- self._attention_mechanism.batch_size,
- message=error_message)]):
+ self._batch_size_checks(cell_batch_size, error_message)):
cell_output = array_ops.identity(
cell_output, name="checked_cell_output")
- alignments = self._attention_mechanism(
- cell_output, previous_alignments=state.alignments)
-
- # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
- expanded_alignments = array_ops.expand_dims(alignments, 1)
- # Context is the inner product of alignments and values along the
- # memory time dimension.
- # alignments shape is
- # [batch_size, 1, memory_time]
- # attention_mechanism.values shape is
- # [batch_size, memory_time, attention_mechanism.num_units]
- # the batched matmul is over memory_time, so the output shape is
- # [batch_size, 1, attention_mechanism.num_units].
- # we then squeeze out the singleton dim.
- attention_mechanism_values = self._attention_mechanism.values
- context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
- context = array_ops.squeeze(context, [1])
-
- if self._attention_layer is not None:
- attention = self._attention_layer(
- array_ops.concat([cell_output, context], 1))
+ if self._is_multi:
+ previous_alignments = state.alignments
+ previous_alignment_history = state.alignment_history
else:
- attention = context
-
- if self._alignment_history:
- alignment_history = state.alignment_history.write(
- state.time, alignments)
- else:
- alignment_history = ()
-
+ previous_alignments = [state.alignments]
+ previous_alignment_history = [state.alignment_history]
+
+ all_alignments = []
+ all_attentions = []
+ all_histories = []
+ for i, attention_mechanism in enumerate(self._attention_mechanisms):
+ attention, alignments = _compute_attention(
+ attention_mechanism, cell_output, previous_alignments[i],
+ self._attention_layers[i] if self._attention_layers else None)
+ alignment_history = previous_alignment_history[i].write(
+ state.time, alignments) if self._alignment_history else ()
+
+ all_alignments.append(alignments)
+ all_histories.append(alignment_history)
+ all_attentions.append(attention)
+
+ attention = array_ops.concat(all_attentions, 1)
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
- alignments=alignments,
- alignment_history=alignment_history)
+ alignments=self._item_or_tuple(all_alignments),
+ alignment_history=self._item_or_tuple(all_histories))
if self._output_attention:
return attention, next_state