author     Eugene Brevdo <ebrevdo@google.com>               2017-03-27 16:00:23 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-03-27 17:17:23 -0700
commit     295ce8f8207675b34174700f0484c938f7bdb3a5 (patch)
tree       d5800645444b03686e678db4fbe75f793930ffd5
parent     7002817f1b1133f6780c2f3d300e5f1c6b7df85e (diff)
[contrib seq2seq] Allow AttentionWrapper to store its attention history.
Not enabled by default; enable it by constructing the wrapper with the flag
attention_history=True.
Change: 151391741
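
Usage sketch of the new flag (an illustration only, assuming this commit's
contrib API; encoder_outputs, memory_sequence_length, and the layer sizes
below are hypothetical placeholders, not part of this change):

    # Hypothetical setup: encoder_outputs is assumed [batch, time, depth],
    # memory_sequence_length an int vector of shape [batch].
    from tensorflow.contrib.rnn import core_rnn_cell
    from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper

    attention_mechanism = wrapper.BahdanauAttention(
        num_units=10, memory=encoder_outputs,
        memory_sequence_length=memory_sequence_length)
    cell = core_rnn_cell.LSTMCell(9)
    cell = wrapper.AttentionWrapper(
        cell, attention_mechanism, attention_size=6,
        attention_history=True)  # new flag; defaults to False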
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 35
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py               | 36
2 files changed, 63 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 5d952e97b7..2aacce9e9c 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -50,7 +50,8 @@ class AttentionWrapperTest(test.TestCase):
                          create_attention_mechanism,
                          expected_final_outputs,
                          expected_final_state,
-                         attention_mechanism_depth=3):
+                         attention_mechanism_depth=3,
+                         attention_history=False):
     encoder_sequence_length = [3, 2, 3, 1, 0]
     decoder_sequence_length = [2, 0, 1, 2, 3]
     batch_size = 5
@@ -77,7 +78,8 @@ class AttentionWrapperTest(test.TestCase):
         initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
       cell = core_rnn_cell.LSTMCell(cell_depth)
       cell = wrapper.AttentionWrapper(
-          cell, attention_mechanism, attention_size=attention_depth)
+          cell, attention_mechanism, attention_size=attention_depth,
+          attention_history=attention_history)
       helper = helper_py.TrainingHelper(decoder_inputs,
                                         decoder_sequence_length)
       my_decoder = basic_decoder.BasicDecoder(
@@ -107,16 +109,33 @@ class AttentionWrapperTest(test.TestCase):
       self.assertEqual((batch_size, cell_depth),
                        tuple(final_state.cell_state.h.get_shape().as_list()))
 
+      if attention_history:
+        state_attention_history = final_state.attention_history.stack()
+        # Remove the history from final_state for purposes of the
+        # remainder of the tests.
+        final_state = final_state._replace(attention_history=())  # pylint: disable=protected-access
+        self.assertEqual((None, batch_size, attention_depth),
+                         tuple(state_attention_history.get_shape().as_list()))
+      else:
+        state_attention_history = ()
+
       sess.run(variables.global_variables_initializer())
       sess_results = sess.run({
           "final_outputs": final_outputs,
-          "final_state": final_state
+          "final_state": final_state,
+          "state_attention_history": state_attention_history,
       })
 
       nest.map_structure(self.assertAllClose, expected_final_outputs,
                          sess_results["final_outputs"])
       nest.map_structure(self.assertAllClose, expected_final_state,
                          sess_results["final_state"])
+      if attention_history:  # by default, the wrapper emits attention as output
+        self.assertAllClose(
+            # outputs are batch major but the stacked TensorArray is time major
+            sess_results["state_attention_history"],
+            np.transpose(sess_results["final_outputs"].rnn_output,
+                         (1, 0, 2)))
 
   def testBahndahauNotNormalized(self):
     create_attention_mechanism = wrapper.BahdanauAttention
@@ -179,6 +198,8 @@ class AttentionWrapperTest(test.TestCase):
             dtype=int32))
     expected_final_state = wrapper.AttentionWrapperState(
+        time=3,
+        attention_history=(),
         cell_state=core_rnn_cell.LSTMStateTuple(
             c=array(
                 [[
@@ -243,7 +264,7 @@
             ]],
             dtype=float32))
     self._testWithAttention(create_attention_mechanism, expected_final_outputs,
-                            expected_final_state)
+                            expected_final_state, attention_history=True)
 
   def testBahndahauNormalized(self):
     create_attention_mechanism = functools.partial(
@@ -307,6 +328,8 @@ class AttentionWrapperTest(test.TestCase):
             dtype=int32))
     expected_final_state = wrapper.AttentionWrapperState(
+        time=3,
+        attention_history=(),
         cell_state=core_rnn_cell.LSTMStateTuple(
             c=array(
                 [[
@@ -435,6 +458,8 @@ class AttentionWrapperTest(test.TestCase):
         [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
         dtype=int32))
     expected_final_state = wrapper.AttentionWrapperState(
+        time=3,
+        attention_history=(),
         cell_state=core_rnn_cell.LSTMStateTuple(
             c=array(
                 [[
@@ -567,6 +592,8 @@ class AttentionWrapperTest(test.TestCase):
         [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
         dtype=int32))
     expected_final_state = wrapper.AttentionWrapperState(
+        time=3,
+        attention_history=(),
         cell_state=core_rnn_cell.LSTMStateTuple(
             c=array(
                 [[
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 0af001d274..562d7a8c76 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -25,6 +25,7 @@ import math
 from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.layers import base as layers_base
 from tensorflow.python.layers import core as layers_core
 from tensorflow.python.ops import array_ops
@@ -32,6 +33,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import nest
 
@@ -378,13 +380,17 @@ class BahdanauAttention(_BaseAttentionMechanism):
 
 class AttentionWrapperState(
     collections.namedtuple(
-        "AttentionWrapperState", ("cell_state", "attention"))):
+        "AttentionWrapperState", (
+            "cell_state", "attention", "time", "attention_history"))):
   """`namedtuple` storing the state of a `AttentionWrapper`.
 
   Contains:
 
     - `cell_state`: The state of the wrapped `RNNCell`.
    - `attention`: The attention emitted at the previous time step.
+    - `time`: int32 scalar containing the current time step.
+    - `attention_history`: (if enabled) a `TensorArray` containing attention
+      matrices from all time steps.  Call `stack()` to convert to a `Tensor`.
   """
   pass
 
@@ -418,6 +424,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
                cell,
                attention_mechanism,
                attention_size,
+               attention_history=False,
                cell_input_fn=None,
                probability_fn=None,
                output_attention=True,
@@ -429,6 +436,9 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
       attention_mechanism: An instance of `AttentionMechanism`.
       attention_size: Python integer, the depth of the attention (output)
         tensor.
+      attention_history: Python boolean, whether to store attention history
+        from all time steps in the final output state (currently stored as a
+        time major `TensorArray` on which you must call `stack()`).
       cell_input_fn: (optional) A `callable`.  The default is:
         `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
       probability_fn: (optional) A `callable`.  Converts the score to
@@ -474,6 +484,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
     self._cell_input_fn = cell_input_fn
     self._probability_fn = probability_fn
     self._output_attention = output_attention
+    self._attention_history = attention_history
 
   @property
   def output_size(self):
@@ -486,14 +497,23 @@
   def state_size(self):
     return AttentionWrapperState(
         cell_state=self._cell.state_size,
-        attention=self._attention_size)
+        time=tensor_shape.TensorShape([]),
+        attention=self._attention_size,
+        attention_history=())  # attention_history is sometimes a TensorArray
 
   def zero_state(self, batch_size, dtype):
     with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      if self._attention_history:
+        attention_history = tensor_array_ops.TensorArray(
+            dtype=dtype, size=0, dynamic_size=True)
+      else:
+        attention_history = ()
       return AttentionWrapperState(
           cell_state=self._cell.zero_state(batch_size, dtype),
+          time=array_ops.zeros([], dtype=dtypes.int32),
           attention=_zero_state_tensors(
-              self._attention_size, batch_size, dtype))
+              self._attention_size, batch_size, dtype),
+          attention_history=attention_history)
 
   def __call__(self, inputs, state, scope=None):
     """Perform a step of attention-wrapped RNN.
@@ -555,9 +575,17 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
     attention = self._attention_layer(
         array_ops.concat([cell_output, context], 1))
 
+    if self._attention_history:
+      attention_history = state.attention_history.write(
+          state.time, attention)
+    else:
+      attention_history = ()
+
     next_state = AttentionWrapperState(
+        time=state.time + 1,
         cell_state=next_cell_state,
-        attention=attention)
+        attention=attention,
+        attention_history=attention_history)
 
     if self._output_attention:
       return attention, next_state
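
Reading the stored history (a sketch, not part of this change): the final
state's attention_history is a time-major TensorArray, so stack it and
transpose to compare against the batch-major decoder output, mirroring the
new test assertion above. Here final_outputs and final_state are assumed to
come from a decoder run over a wrapper built with attention_history=True.

    # Assumed inputs: final_outputs, final_state from a dynamic decode of
    # an AttentionWrapper constructed with attention_history=True.
    history = final_state.attention_history.stack()    # [time, batch, attention_size]
    history = array_ops.transpose(history, [1, 0, 2])  # [batch, time, attention_size]
    # With the default output_attention=True, this matches
    # final_outputs.rnn_output, as the updated test asserts.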