From 965ae3e34e5e5dd33a4910fdf3a31d0f7a250ac1 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Wed, 19 Jul 2017 17:08:07 -0700 Subject: Add multi-head attention capabilities to AttentionWrapper via the specification of multiple AttentionMechanisms. PiperOrigin-RevId: 162557562 --- .../python/kernel_tests/attention_wrapper_test.py | 172 ++++++++++++++-- .../seq2seq/python/ops/attention_wrapper.py | 216 ++++++++++++++------- 2 files changed, 307 insertions(+), 81 deletions(-) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 8401918128..91493302b1 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -88,6 +88,30 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=None, attention_layer_size=6, name=''): + self._testWithMaybeMultiAttention( + is_multi=False, + create_attention_mechanisms=[create_attention_mechanism], + expected_final_output=expected_final_output, + expected_final_state=expected_final_state, + attention_mechanism_depths=[attention_mechanism_depth], + alignment_history=alignment_history, + expected_final_alignment_history=expected_final_alignment_history, + attention_layer_sizes=[attention_layer_size], + name=name) + + def _testWithMaybeMultiAttention(self, + is_multi, + create_attention_mechanisms, + expected_final_output, + expected_final_state, + attention_mechanism_depths, + alignment_history=False, + expected_final_alignment_history=None, + attention_layer_sizes=None, + name=''): + # Allow is_multi to be True with a single mechanism to enable test for + # passing in a single mechanism in a list. + assert len(create_attention_mechanisms) == 1 or is_multi encoder_sequence_length = [3, 2, 3, 1, 1] decoder_sequence_length = [2, 0, 1, 2, 3] batch_size = 5 @@ -97,10 +121,12 @@ class AttentionWrapperTest(test.TestCase): encoder_output_depth = 10 cell_depth = 9 - if attention_layer_size is not None: - attention_depth = attention_layer_size + if attention_layer_sizes is None: + attention_depth = encoder_output_depth * len(create_attention_mechanisms) else: - attention_depth = encoder_output_depth + # Compute sum of attention_layer_sizes. Use encoder_output_depth if None. 
+ attention_depth = sum([attention_layer_size or encoder_output_depth + for attention_layer_size in attention_layer_sizes]) decoder_inputs = array_ops.placeholder_with_default( np.random.randn(batch_size, decoder_max_time, @@ -111,10 +137,12 @@ class AttentionWrapperTest(test.TestCase): encoder_output_depth).astype(np.float32), shape=(None, None, encoder_output_depth)) - attention_mechanism = create_attention_mechanism( - num_units=attention_mechanism_depth, - memory=encoder_outputs, - memory_sequence_length=encoder_sequence_length) + attention_mechanisms = [ + creator(num_units=depth, + memory=encoder_outputs, + memory_sequence_length=encoder_sequence_length) + for creator, depth in zip(create_attention_mechanisms, + attention_mechanism_depths)] with self.test_session(use_gpu=True) as sess: with vs.variable_scope( @@ -123,8 +151,9 @@ class AttentionWrapperTest(test.TestCase): cell = rnn_cell.LSTMCell(cell_depth) cell = wrapper.AttentionWrapper( cell, - attention_mechanism, - attention_layer_size=attention_layer_size, + attention_mechanisms if is_multi else attention_mechanisms[0], + attention_layer_size=(attention_layer_sizes if is_multi + else attention_layer_sizes[0]), alignment_history=alignment_history) helper = helper_py.TrainingHelper(decoder_inputs, decoder_sequence_length) @@ -156,12 +185,23 @@ class AttentionWrapperTest(test.TestCase): tuple(final_state.cell_state.h.get_shape().as_list())) if alignment_history: - state_alignment_history = final_state.alignment_history.stack() + if is_multi: + state_alignment_history = [] + for history_array in final_state.alignment_history: + history = history_array.stack() + self.assertEqual( + (None, batch_size, None), + tuple(history.get_shape().as_list())) + state_alignment_history.append(history) + state_alignment_history = tuple(state_alignment_history) + else: + state_alignment_history = final_state.alignment_history.stack() + self.assertEqual( + (None, batch_size, None), + tuple(state_alignment_history.get_shape().as_list())) # Remove the history from final_state for purposes of the # remainder of the tests. 
final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access - self.assertEqual((None, batch_size, None), - tuple(state_alignment_history.get_shape().as_list())) else: state_alignment_history = () @@ -617,6 +657,114 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=expected_final_alignment_history, name='testLuongMonotonicScaled') + def testMultiAttention(self): + create_attention_mechanisms = ( + wrapper.BahdanauAttention, wrapper.LuongAttention) + + expected_final_output = BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 7), dtype=dtype('float32'), mean=0.0011709079), + sample_id=ResultSummary( + shape=(5, 3), dtype=dtype('int32'), mean=3.2000000000000002)) + expected_final_state = AttentionWrapperState( + cell_state=LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0038725811), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0019329828)), + attention=ResultSummary( + shape=(5, 7), dtype=dtype('float32'), mean=0.001174294), + time=3, + alignments=( + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), + alignment_history=()) + + expected_final_alignment_history = ( + ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125), + ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125)) + + self._testWithMaybeMultiAttention( + True, + create_attention_mechanisms, + expected_final_output, + expected_final_state, + attention_mechanism_depths=[9, 9], + attention_layer_sizes=[3, 4], + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + name='testMultiAttention') + + def testMultiAttentionNoAttentionLayer(self): + create_attention_mechanisms = ( + wrapper.BahdanauAttention, wrapper.LuongAttention) + + expected_final_output = BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11691988), + sample_id=ResultSummary( + shape=(5, 3), dtype=dtype('int32'), mean=7.2666666666666666)) + expected_final_state = AttentionWrapperState( + cell_state=LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0036486709), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)), + attention=ResultSummary( + shape=(5, 20), dtype=dtype('float32'), mean=0.11680689), + time=3, + alignments=( + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), + alignment_history=()) + expected_final_alignment_history = ( + ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125), + ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125)) + + self._testWithMaybeMultiAttention( + is_multi=True, + create_attention_mechanisms=create_attention_mechanisms, + expected_final_output=expected_final_output, + expected_final_state=expected_final_state, + attention_mechanism_depths=[9, 9], + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + name='testMultiAttention') + + def testSingleAttentionAsList(self): + create_attention_mechanisms = [wrapper.BahdanauAttention] + + expected_final_output = BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 3), dtype=dtype('float32'), mean=-0.0098485695), + sample_id=ResultSummary( + shape=(5, 3), dtype=dtype('int32'), mean=1.8)) + expected_final_state = AttentionWrapperState( + 
cell_state=LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0040023471), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0019979973)), + attention=ResultSummary( + shape=(5, 3), dtype=dtype('float32'), mean=-0.0098808752), + time=3, + alignments=( + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),), + alignment_history=()) + + expected_final_alignment_history = ( + ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),) + + self._testWithMaybeMultiAttention( + is_multi=True, # pass the AttentionMechanism wrapped in a list + create_attention_mechanisms=create_attention_mechanisms, + expected_final_output=expected_final_output, + expected_final_state=expected_final_state, + attention_mechanism_depths=[9], + attention_layer_sizes=[3], + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + name='testMultiAttention') if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 9c6939bb46..a162a919cf 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -917,9 +917,11 @@ class AttentionWrapperState( step. - `attention`: The attention emitted at the previous time step. - `time`: int32 scalar containing the current time step. - - `alignments`: The alignment emitted at the previous time step. - - `alignment_history`: (if enabled) a `TensorArray` containing alignment - matrices from all time steps. Call `stack()` to convert to a `Tensor`. + - `alignments`: A single or tuple of `Tensor`(s) containing the alignments + emitted at the previous time step for each attention mechanism. + - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s) + containing alignment matrices from all time steps for each attention + mechanism. Call `stack()` on each to convert to a `Tensor`. """ def clone(self, **kwargs): @@ -964,6 +966,34 @@ def hardmax(logits, name=None): math_ops.argmax(logits, -1), depth, dtype=logits.dtype) +def _compute_attention(attention_mechanism, cell_output, previous_alignments, + attention_layer): + """Computes the attention and alignments for a given attention_mechanism.""" + alignments = attention_mechanism( + cell_output, previous_alignments=previous_alignments) + + # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] + expanded_alignments = array_ops.expand_dims(alignments, 1) + # Context is the inner product of alignments and values along the + # memory time dimension. + # alignments shape is + # [batch_size, 1, memory_time] + # attention_mechanism.values shape is + # [batch_size, memory_time, attention_mechanism.num_units] + # the batched matmul is over memory_time, so the output shape is + # [batch_size, 1, attention_mechanism.num_units]. + # we then squeeze out the singleton dim. + context = math_ops.matmul(expanded_alignments, attention_mechanism.values) + context = array_ops.squeeze(context, [1]) + + if attention_layer is not None: + attention = attention_layer(array_ops.concat([cell_output, context], 1)) + else: + attention = context + + return attention, alignments + + class AttentionWrapper(rnn_cell_impl.RNNCell): """Wraps another `RNNCell` with attention. """ @@ -981,11 +1011,14 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): Args: cell: An instance of `RNNCell`. - attention_mechanism: An instance of `AttentionMechanism`. 
- attention_layer_size: Python integer, the depth of the attention (output) - layer. If None (default), use the context as attention at each time - step. Otherwise, feed the context and cell output into the attention - layer to generate attention at each time step. + attention_mechanism: A list of `AttentionMechanism` instances or a single + instance. + attention_layer_size: A list of Python integers or a single Python + integer, the depth of the attention (output) layer(s). If None + (default), use the context as attention at each time step. Otherwise, + feed the context and cell output into the attention layer to generate + attention at each time step. If attention_mechanism is a list, + attention_layer_size must be a list of the same length. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). @@ -1005,15 +1038,35 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): does not match the batch size of `initial_cell_state`, proper behavior is not guaranteed. name: Name to use when creating ops. + + Raises: + TypeError: `attention_layer_size` is not None and (`attention_mechanism` + is a list but `attention_layer_size` is not; or vice versa). + ValueError: if `attention_layer_size` is not None, `attention_mechanism` + is a list, and its length does not match that of `attention_layer_size`. """ super(AttentionWrapper, self).__init__(name=name) if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError( "cell must be an RNNCell, saw type: %s" % type(cell).__name__) - if not isinstance(attention_mechanism, AttentionMechanism): - raise TypeError( - "attention_mechanism must be a AttentionMechanism, saw type: %s" - % type(attention_mechanism).__name__) + if isinstance(attention_mechanism, (list, tuple)): + self._is_multi = True + attention_mechanisms = attention_mechanism + for attention_mechanism in attention_mechanisms: + if not isinstance(attention_mechanism, AttentionMechanism): + raise TypeError( + "attention_mechanism must contain only instances of " + "AttentionMechanism, saw type: %s" + % type(attention_mechanism).__name__) + else: + self._is_multi = False + if not isinstance(attention_mechanism, AttentionMechanism): + raise TypeError( + "attention_mechanism must be an AttentionMechanism or list of " + "multiple AttentionMechanism instances, saw type: %s" + % type(attention_mechanism).__name__) + attention_mechanisms = (attention_mechanism,) + if cell_input_fn is None: cell_input_fn = ( lambda inputs, attention: array_ops.concat([inputs, attention], -1)) @@ -1024,16 +1077,28 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): % type(cell_input_fn).__name__) if attention_layer_size is not None: - self._attention_layer = layers_core.Dense( - attention_layer_size, name="attention_layer", use_bias=False) - self._attention_layer_size = attention_layer_size + attention_layer_sizes = tuple( + attention_layer_size + if isinstance(attention_layer_size, (list, tuple)) + else (attention_layer_size,)) + if len(attention_layer_sizes) != len(attention_mechanisms): + raise ValueError( + "If provided, attention_layer_size must contain exactly one " + "integer per attention_mechanism, saw: %d vs %d" + % (len(attention_layer_sizes), len(attention_mechanisms))) + self._attention_layers = tuple( + layers_core.Dense( + attention_layer_size, name="attention_layer", use_bias=False) + for attention_layer_size in attention_layer_sizes) + 
self._attention_layer_size = sum(attention_layer_sizes) else: - self._attention_layer = None - self._attention_layer_size = attention_mechanism.values.get_shape()[ - -1].value + self._attention_layers = None + self._attention_layer_size = sum( + attention_mechanism.values.get_shape()[-1].value + for attention_mechanism in attention_mechanisms) self._cell = cell - self._attention_mechanism = attention_mechanism + self._attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn self._output_attention = output_attention self._alignment_history = alignment_history @@ -1053,13 +1118,36 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): "via the tf.contrib.seq2seq.tile_batch function with argument " "multiple=beam_width.") with ops.control_dependencies( - [check_ops.assert_equal(state_batch_size, - self._attention_mechanism.batch_size, - message=error_message)]): + self._batch_size_checks(state_batch_size, error_message)): self._initial_cell_state = nest.map_structure( lambda s: array_ops.identity(s, name="check_initial_cell_state"), initial_cell_state) + def _batch_size_checks(self, batch_size, error_message): + return [check_ops.assert_equal(batch_size, + attention_mechanism.batch_size, + message=error_message) + for attention_mechanism in self._attention_mechanisms] + + def _item_or_tuple(self, seq): + """Returns `seq` as tuple or the singular element. + + Which is returned is determined by how the AttentionMechanism(s) were passed + to the constructor. + + Args: + seq: A non-empty sequence of items or generator. + + Returns: + Either the values in the sequence as a tuple if AttentionMechanism(s) + were passed to the constructor as a sequence or the singular element. + """ + t = tuple(seq) + if self._is_multi: + return t + else: + return t[0] + @property def output_size(self): if self._output_attention: @@ -1073,8 +1161,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): cell_state=self._cell.state_size, time=tensor_shape.TensorShape([]), attention=self._attention_layer_size, - alignments=self._attention_mechanism.alignments_size, - alignment_history=()) # alignment_history is sometimes a TensorArray + alignments=self._item_or_tuple( + a.alignments_size for a in self._attention_mechanisms), + alignment_history=self._item_or_tuple( + () for _ in self._attention_mechanisms)) # sometimes a TensorArray def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): @@ -1091,25 +1181,23 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): "the batch_size= argument passed to zero_state is " "batch_size * beam_width.") with ops.control_dependencies( - [check_ops.assert_equal(batch_size, - self._attention_mechanism.batch_size, - message=error_message)]): + self._batch_size_checks(batch_size, error_message)): cell_state = nest.map_structure( lambda s: array_ops.identity(s, name="checked_cell_state"), cell_state) - if self._alignment_history: - alignment_history = tensor_array_ops.TensorArray( - dtype=dtype, size=0, dynamic_size=True) - else: - alignment_history = () return AttentionWrapperState( cell_state=cell_state, time=array_ops.zeros([], dtype=dtypes.int32), attention=_zero_state_tensors(self._attention_layer_size, batch_size, dtype), - alignments=self._attention_mechanism.initial_alignments( - batch_size, dtype), - alignment_history=alignment_history) + alignments=self._item_or_tuple( + attention_mechanism.initial_alignments(batch_size, dtype) + for attention_mechanism in self._attention_mechanisms), + 
alignment_history=self._item_or_tuple( + tensor_array_ops.TensorArray(dtype=dtype, size=0, + dynamic_size=True) + if self._alignment_history else () + for _ in self._attention_mechanisms)) def call(self, inputs, state): """Perform a step of attention-wrapped RNN. @@ -1161,48 +1249,38 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): "the tf.contrib.seq2seq.tile_batch function with argument " "multiple=beam_width.") with ops.control_dependencies( - [check_ops.assert_equal(cell_batch_size, - self._attention_mechanism.batch_size, - message=error_message)]): + self._batch_size_checks(cell_batch_size, error_message)): cell_output = array_ops.identity( cell_output, name="checked_cell_output") - alignments = self._attention_mechanism( - cell_output, previous_alignments=state.alignments) - - # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] - expanded_alignments = array_ops.expand_dims(alignments, 1) - # Context is the inner product of alignments and values along the - # memory time dimension. - # alignments shape is - # [batch_size, 1, memory_time] - # attention_mechanism.values shape is - # [batch_size, memory_time, attention_mechanism.num_units] - # the batched matmul is over memory_time, so the output shape is - # [batch_size, 1, attention_mechanism.num_units]. - # we then squeeze out the singleton dim. - attention_mechanism_values = self._attention_mechanism.values - context = math_ops.matmul(expanded_alignments, attention_mechanism_values) - context = array_ops.squeeze(context, [1]) - - if self._attention_layer is not None: - attention = self._attention_layer( - array_ops.concat([cell_output, context], 1)) + if self._is_multi: + previous_alignments = state.alignments + previous_alignment_history = state.alignment_history else: - attention = context - - if self._alignment_history: - alignment_history = state.alignment_history.write( - state.time, alignments) - else: - alignment_history = () - + previous_alignments = [state.alignments] + previous_alignment_history = [state.alignment_history] + + all_alignments = [] + all_attentions = [] + all_histories = [] + for i, attention_mechanism in enumerate(self._attention_mechanisms): + attention, alignments = _compute_attention( + attention_mechanism, cell_output, previous_alignments[i], + self._attention_layers[i] if self._attention_layers else None) + alignment_history = previous_alignment_history[i].write( + state.time, alignments) if self._alignment_history else () + + all_alignments.append(alignments) + all_histories.append(alignment_history) + all_attentions.append(attention) + + attention = array_ops.concat(all_attentions, 1) next_state = AttentionWrapperState( time=state.time + 1, cell_state=next_cell_state, attention=attention, - alignments=alignments, - alignment_history=alignment_history) + alignments=self._item_or_tuple(all_alignments), + alignment_history=self._item_or_tuple(all_histories)) if self._output_attention: return attention, next_state -- cgit v1.2.3
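
Usage sketch (illustrative, not part of the patch): the snippet below wires an AttentionWrapper with two attention mechanisms, the capability this change adds. The placeholder shapes, num_units, and attention_layer_size=[3, 4] mirror the new testMultiAttention test rather than any recommended settings, and assume the TF 1.x tf.contrib.seq2seq API as modified above.

    import tensorflow as tf
    from tensorflow.contrib import seq2seq

    batch_size, max_time, encoder_depth, cell_depth = 5, 8, 10, 9

    encoder_outputs = tf.placeholder(
        tf.float32, [batch_size, max_time, encoder_depth])
    encoder_sequence_length = tf.placeholder(tf.int32, [batch_size])

    # Two attention heads over the same memory; any mix of mechanisms works.
    attention_mechanisms = [
        seq2seq.BahdanauAttention(
            num_units=cell_depth, memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length),
        seq2seq.LuongAttention(
            num_units=cell_depth, memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length),
    ]

    cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
    # attention_layer_size must supply one integer per mechanism; passing None
    # instead concatenates the raw per-head contexts, so the attention depth
    # becomes the sum of the memory depths.
    attention_cell = seq2seq.AttentionWrapper(
        cell,
        attention_mechanisms,
        attention_layer_size=[3, 4],
        alignment_history=True)

    # Because a list was passed, state.alignments and state.alignment_history
    # are tuples with one entry per mechanism, and the emitted attention is the
    # concatenation of the per-head attention vectors (depth 3 + 4 = 7 here).
    initial_state = attention_cell.zero_state(batch_size, tf.float32)

Passing a single AttentionMechanism (not wrapped in a list) keeps the previous behavior and state layout, while a single mechanism inside a list opts into the tuple-valued alignments shown in testSingleAttentionAsList.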