author     Guillaume Klein <guillaumekln@users.noreply.github.com>    2018-04-16 04:21:29 +0200
committer  Jonathan Hseu <vomjom@vomjom.net>                          2018-04-15 19:21:29 -0700
commit     0586c57292a7bd1a79b4a03270c0f1c32d02a4af (patch)
tree       8f582fd4dc9b67753d87b815f0193a1e6e3c2995 /tensorflow/contrib/seq2seq
parent     54772bb9a4a44badf4a70d75f41426c51f47cf3e (diff)
Support passing layer instances to produce attentional hidden states (#14974)
* Support passing Layer instances to the AttentionWrapper.
* Use _compute_output_shape to get the attention layer depth
* compute_output_shape is now a public method
* Move the new argument to the end
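For orientation, a minimal usage sketch of what the new argument enables (not part of the commit; the shapes and names below are illustrative, written against the TF 1.x contrib API this change targets):

```python
import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.python.layers import core as layers_core

# Illustrative encoder output: [batch_size=5, max_time=8, depth=10].
encoder_outputs = tf.placeholder(tf.float32, [5, 8, 10])

cell = tf.nn.rnn_cell.LSTMCell(9)
mechanism = seq2seq.LuongAttention(num_units=9, memory=encoder_outputs)

# Previously only attention_layer_size=<int> was accepted and the wrapper
# built a Dense layer internally; with this change a pre-constructed Layer
# instance can be passed instead. The two arguments are mutually exclusive.
attn_cell = seq2seq.AttentionWrapper(
    cell,
    mechanism,
    attention_layer=layers_core.Dense(6, use_bias=False))
```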
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 77
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py               | 35
2 files changed, 102 insertions(+), 10 deletions(-)
```diff
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index d508cf3f9d..84a7b45b5a 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
 from tensorflow.contrib.seq2seq.python.ops import basic_decoder
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.layers import core as layers_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
@@ -110,7 +111,12 @@ class AttentionWrapperTest(test.TestCase):
                         alignment_history=False,
                         expected_final_alignment_history=None,
                         attention_layer_size=6,
+                        attention_layer=None,
                         name=''):
+    attention_layer_sizes = (
+        [attention_layer_size] if attention_layer_size is not None else None)
+    attention_layers = (
+        [attention_layer] if attention_layer is not None else None)
     self._testWithMaybeMultiAttention(
         is_multi=False,
         create_attention_mechanisms=[create_attention_mechanism],
@@ -119,7 +125,8 @@ class AttentionWrapperTest(test.TestCase):
         attention_mechanism_depths=[attention_mechanism_depth],
         alignment_history=alignment_history,
         expected_final_alignment_history=expected_final_alignment_history,
-        attention_layer_sizes=[attention_layer_size],
+        attention_layer_sizes=attention_layer_sizes,
+        attention_layers=attention_layers,
         name=name)
 
   def _testWithMaybeMultiAttention(self,
@@ -131,6 +138,7 @@ class AttentionWrapperTest(test.TestCase):
                                    alignment_history=False,
                                    expected_final_alignment_history=None,
                                    attention_layer_sizes=None,
+                                   attention_layers=None,
                                    name=''):
     # Allow is_multi to be True with a single mechanism to enable test for
     # passing in a single mechanism in a list.
@@ -144,12 +152,18 @@ class AttentionWrapperTest(test.TestCase):
     encoder_output_depth = 10
     cell_depth = 9
 
-    if attention_layer_sizes is None:
-      attention_depth = encoder_output_depth * len(create_attention_mechanisms)
-    else:
+    if attention_layer_sizes is not None:
       # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
       attention_depth = sum([attention_layer_size or encoder_output_depth
                              for attention_layer_size in attention_layer_sizes])
+    elif attention_layers is not None:
+      # Compute sum of attention_layers output depth.
+      attention_depth = sum(
+          attention_layer.compute_output_shape(
+              [batch_size, cell_depth + encoder_output_depth])[-1].value
+          for attention_layer in attention_layers)
+    else:
+      attention_depth = encoder_output_depth * len(create_attention_mechanisms)
 
     decoder_inputs = array_ops.placeholder_with_default(
         np.random.randn(batch_size, decoder_max_time,
@@ -171,13 +185,20 @@ class AttentionWrapperTest(test.TestCase):
     with vs.variable_scope(
         'root',
         initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
+      attention_layer_size = attention_layer_sizes
+      attention_layer = attention_layers
+      if not is_multi:
+        if attention_layer_size is not None:
+          attention_layer_size = attention_layer_size[0]
+        if attention_layer is not None:
+          attention_layer = attention_layer[0]
       cell = rnn_cell.LSTMCell(cell_depth)
       cell = wrapper.AttentionWrapper(
           cell,
           attention_mechanisms if is_multi else attention_mechanisms[0],
-          attention_layer_size=(attention_layer_sizes if is_multi
-                                else attention_layer_sizes[0]),
-          alignment_history=alignment_history)
+          attention_layer_size=attention_layer_size,
+          alignment_history=alignment_history,
+          attention_layer=attention_layer)
       helper = helper_py.TrainingHelper(decoder_inputs,
                                         decoder_sequence_length)
       my_decoder = basic_decoder.BasicDecoder(
@@ -797,6 +818,48 @@ class AttentionWrapperTest(test.TestCase):
         expected_final_alignment_history=expected_final_alignment_history,
         name='testMultiAttention')
 
+  def testMultiAttentionWithLayerInstances(self):
+    create_attention_mechanisms = (
+        wrapper.BahdanauAttention, wrapper.LuongAttention)
+
+    expected_final_output = BasicDecoderOutput(
+        rnn_output=ResultSummary(
+            shape=(5, 3, 7), dtype=dtype('float32'), mean=0.0011709079),
+        sample_id=ResultSummary(
+            shape=(5, 3), dtype=dtype('int32'), mean=3.2000000000000002))
+    expected_final_state = AttentionWrapperState(
+        cell_state=LSTMStateTuple(
+            c=ResultSummary(
+                shape=(5, 9), dtype=dtype('float32'), mean=-0.0038725811),
+            h=ResultSummary(
+                shape=(5, 9), dtype=dtype('float32'), mean=-0.0019329828)),
+        attention=ResultSummary(
+            shape=(5, 7), dtype=dtype('float32'), mean=0.001174294),
+        time=3,
+        alignments=(
+            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+        attention_state=(
+            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+        alignment_history=())
+
+    expected_final_alignment_history = (
+        ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
+        ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
+
+    self._testWithMaybeMultiAttention(
+        True,
+        create_attention_mechanisms,
+        expected_final_output,
+        expected_final_state,
+        attention_mechanism_depths=[9, 9],
+        attention_layers=[layers_core.Dense(3, use_bias=False),
+                          layers_core.Dense(4, use_bias=False)],
+        alignment_history=True,
+        expected_final_alignment_history=expected_final_alignment_history,
+        name='testMultiAttention')
+
   def testLuongMonotonicHard(self):
     # Run attention mechanism with mode='hard', make sure probabilities are hard
     b, t, u, d = 10, 20, 30, 40
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index f0f143ddfc..9ba541ce23 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -1082,7 +1082,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
                cell_input_fn=None,
                output_attention=True,
                initial_cell_state=None,
-               name=None):
+               name=None,
+               attention_layer=None):
     """Construct the `AttentionWrapper`.
 
     **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
@@ -1125,7 +1126,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
         (default), use the context as attention at each time step. Otherwise,
         feed the context and cell output into the attention layer to generate
         attention at each time step. If attention_mechanism is a list,
-        attention_layer_size must be a list of the same length.
+        attention_layer_size must be a list of the same length. If
+        attention_layer is set, this must be None.
       alignment_history: Python boolean, whether to store alignment history
         from all time steps in the final output state (currently stored as a
         time major `TensorArray` on which you must call `stack()`).
@@ -1145,12 +1147,19 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
         does not match the batch size of `initial_cell_state`, proper
         behavior is not guaranteed.
       name: Name to use when creating ops.
+      attention_layer: A list of `tf.layers.Layer` instances or a
+        single `tf.layers.Layer` instance taking the context and cell output as
+        inputs to generate attention at each time step. If None (default), use
+        the context as attention at each time step. If attention_mechanism is a
+        list, attention_layer must be a list of the same length. If
+        attention_layers_size is set, this must be None.
 
     Raises:
       TypeError: `attention_layer_size` is not None and (`attention_mechanism`
         is a list but `attention_layer_size` is not; or vice versa).
       ValueError: if `attention_layer_size` is not None, `attention_mechanism`
-        is a list, and its length does not match that of `attention_layer_size`.
+        is a list, and its length does not match that of `attention_layer_size`;
+        if `attention_layer_size` and `attention_layer` are set simultaneously.
     """
     super(AttentionWrapper, self).__init__(name=name)
     rnn_cell_impl.assert_like_rnncell("cell", cell)
@@ -1181,6 +1190,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
           "cell_input_fn must be callable, saw type: %s"
           % type(cell_input_fn).__name__)
 
+    if attention_layer_size is not None and attention_layer is not None:
+      raise ValueError("Only one of attention_layer_size and attention_layer "
+                       "should be set")
+
     if attention_layer_size is not None:
       attention_layer_sizes = tuple(
          attention_layer_size
@@ -1199,6 +1212,22 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
               dtype=attention_mechanisms[i].dtype)
           for i, attention_layer_size in enumerate(attention_layer_sizes))
       self._attention_layer_size = sum(attention_layer_sizes)
+    elif attention_layer is not None:
+      self._attention_layers = tuple(
+          attention_layer
+          if isinstance(attention_layer, (list, tuple))
+          else (attention_layer,))
+      if len(self._attention_layers) != len(attention_mechanisms):
+        raise ValueError(
+            "If provided, attention_layer must contain exactly one "
+            "layer per attention_mechanism, saw: %d vs %d"
+            % (len(self._attention_layers), len(attention_mechanisms)))
+      self._attention_layer_size = sum(
+          layer.compute_output_shape(
+              [None,
+               cell.output_size + mechanism.values.shape[-1].value])[-1].value
+          for layer, mechanism in zip(
+              self._attention_layers, attention_mechanisms))
     else:
       self._attention_layers = None
       self._attention_layer_size = sum(
```
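The wrapper must know its attention depth statically (it is reported through `output_size` and the `attention` entry of the wrapper state), so the new branch derives it by calling `compute_output_shape` on each supplied layer with the depth of the concatenated `[cell_output; context]` input. A small sketch of that computation in isolation, reusing the depths from the test above (illustrative, not part of the commit):

```python
from tensorflow.python.layers import core as layers_core

cell_output_size = 9  # LSTM cell depth used in the test
memory_depth = 10     # encoder output depth used in the test

# The attention layer consumes the cell output concatenated with the
# attention context, so its input depth is cell_output_size + memory_depth.
layer = layers_core.Dense(6, use_bias=False)
attention_depth = layer.compute_output_shape(
    [None, cell_output_size + memory_depth])[-1].value

assert attention_depth == 6  # equal to the Dense layer's units
```

This is also why the commit message notes that `compute_output_shape` became a public method: any `tf.layers.Layer` passed as `attention_layer` must be able to report a fully defined last output dimension before the graph is built.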