author    Eugene Brevdo <ebrevdo@google.com>  2017-03-27 16:00:23 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-03-27 17:17:23 -0700
commit    295ce8f8207675b34174700f0484c938f7bdb3a5 (patch)
tree      d5800645444b03686e678db4fbe75f793930ffd5
parent    7002817f1b1133f6780c2f3d300e5f1c6b7df85e (diff)
[contrib seq2seq] Allow AttentionWrapper to store its attention history.
Not enabled by default; enable it by passing the constructor flag attention_history=True.
Change: 151391741
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  | 35
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py                | 36
2 files changed, 63 insertions(+), 8 deletions(-)
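
For readers who want to try the flag, here is a minimal usage sketch assuming the contrib seq2seq module layout at this commit; the encoder/decoder inputs, sequence lengths, and size constants are placeholder names (not part of this change), and dynamic_decode's two-value return is assumed for this point in the contrib API's history:

# Hypothetical end-to-end sketch of attention_history=True.
import tensorflow as tf
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py

batch_size, max_time, depth = 5, 8, 6
encoder_outputs = tf.random_normal([batch_size, max_time, depth])
encoder_sequence_length = tf.constant([8, 6, 8, 4, 2])
decoder_inputs = tf.random_normal([batch_size, 3, depth])
decoder_sequence_length = tf.constant([3, 3, 3, 3, 3])

attention_mechanism = wrapper.BahdanauAttention(
    num_units=depth, memory=encoder_outputs,
    memory_sequence_length=encoder_sequence_length)

cell = core_rnn_cell.LSTMCell(depth)
cell = wrapper.AttentionWrapper(
    cell, attention_mechanism, attention_size=depth,
    attention_history=True)  # opt in; the default remains False

my_helper = helper_py.TrainingHelper(decoder_inputs, decoder_sequence_length)
my_decoder = basic_decoder.BasicDecoder(
    cell=cell, helper=my_helper,
    initial_state=cell.zero_state(batch_size, tf.float32))
final_outputs, final_state = decoder.dynamic_decode(my_decoder)

# attention_history is a time-major TensorArray; stack() yields a Tensor of
# shape [time, batch_size, attention_size].
attention_history = final_state.attention_history.stack()
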
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 5d952e97b7..2aacce9e9c 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -50,7 +50,8 @@ class AttentionWrapperTest(test.TestCase):
create_attention_mechanism,
expected_final_outputs,
expected_final_state,
- attention_mechanism_depth=3):
+ attention_mechanism_depth=3,
+ attention_history=False):
encoder_sequence_length = [3, 2, 3, 1, 0]
decoder_sequence_length = [2, 0, 1, 2, 3]
batch_size = 5
@@ -77,7 +78,8 @@ class AttentionWrapperTest(test.TestCase):
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
cell = core_rnn_cell.LSTMCell(cell_depth)
cell = wrapper.AttentionWrapper(
- cell, attention_mechanism, attention_size=attention_depth)
+ cell, attention_mechanism, attention_size=attention_depth,
+ attention_history=attention_history)
helper = helper_py.TrainingHelper(decoder_inputs,
decoder_sequence_length)
my_decoder = basic_decoder.BasicDecoder(
@@ -107,16 +109,33 @@ class AttentionWrapperTest(test.TestCase):
self.assertEqual((batch_size, cell_depth),
tuple(final_state.cell_state.h.get_shape().as_list()))
+ if attention_history:
+ state_attention_history = final_state.attention_history.stack()
+ # Remove the history from final_state for purposes of the
+ # remainder of the tests.
+ final_state = final_state._replace(attention_history=()) # pylint: disable=protected-access
+ self.assertEqual((None, batch_size, attention_depth),
+ tuple(state_attention_history.get_shape().as_list()))
+ else:
+ state_attention_history = ()
+
sess.run(variables.global_variables_initializer())
sess_results = sess.run({
"final_outputs": final_outputs,
- "final_state": final_state
+ "final_state": final_state,
+ "state_attention_history": state_attention_history,
})
nest.map_structure(self.assertAllClose, expected_final_outputs,
sess_results["final_outputs"])
nest.map_structure(self.assertAllClose, expected_final_state,
sess_results["final_state"])
+ if attention_history: # by default, the wrapper emits attention as output
+ self.assertAllClose(
+ # outputs are batch major but the stacked TensorArray is time major
+ sess_results["state_attention_history"],
+ np.transpose(sess_results["final_outputs"].rnn_output,
+ (1, 0, 2)))
def testBahndahauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttention
@@ -179,6 +198,8 @@ class AttentionWrapperTest(test.TestCase):
dtype=int32))
expected_final_state = wrapper.AttentionWrapperState(
+ time=3,
+ attention_history=(),
cell_state=core_rnn_cell.LSTMStateTuple(
c=array(
[[
@@ -243,7 +264,7 @@ class AttentionWrapperTest(test.TestCase):
]],
dtype=float32))
self._testWithAttention(create_attention_mechanism, expected_final_outputs,
- expected_final_state)
+ expected_final_state, attention_history=True)
def testBahndahauNormalized(self):
create_attention_mechanism = functools.partial(
@@ -307,6 +328,8 @@ class AttentionWrapperTest(test.TestCase):
dtype=int32))
expected_final_state = wrapper.AttentionWrapperState(
+ time=3,
+ attention_history=(),
cell_state=core_rnn_cell.LSTMStateTuple(
c=array(
[[
@@ -435,6 +458,8 @@ class AttentionWrapperTest(test.TestCase):
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
dtype=int32))
expected_final_state = wrapper.AttentionWrapperState(
+ time=3,
+ attention_history=(),
cell_state=core_rnn_cell.LSTMStateTuple(
c=array(
[[
@@ -567,6 +592,8 @@ class AttentionWrapperTest(test.TestCase):
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
dtype=int32))
expected_final_state = wrapper.AttentionWrapperState(
+ time=3,
+ attention_history=(),
cell_state=core_rnn_cell.LSTMStateTuple(
c=array(
[[
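
The new history assertion in the test above hinges on the stacked TensorArray being time major while the decoder's rnn_output is batch major (and, with output_attention=True, equal to the attention at each step). A standalone NumPy illustration of that relationship, with arbitrary placeholder shapes:

import numpy as np

time, batch_size, attention_depth = 3, 5, 6

# Stand-in for the stacked attention history: [time, batch, depth].
state_attention_history = np.random.randn(time, batch_size, attention_depth)

# rnn_output from the decoder is batch major: [batch, time, depth].
rnn_output = np.transpose(state_attention_history, (1, 0, 2))

# Transposing back recovers the time-major history, which is exactly
# what the test asserts.
np.testing.assert_allclose(
    state_attention_history, np.transpose(rnn_output, (1, 0, 2)))
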
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 0af001d274..562d7a8c76 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -25,6 +25,7 @@ import math
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_base
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
@@ -32,6 +33,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
@@ -378,13 +380,17 @@ class BahdanauAttention(_BaseAttentionMechanism):
class AttentionWrapperState(
collections.namedtuple(
- "AttentionWrapperState", ("cell_state", "attention"))):
+ "AttentionWrapperState", (
+ "cell_state", "attention", "time", "attention_history"))):
"""`namedtuple` storing the state of a `AttentionWrapper`.
Contains:
- `cell_state`: The state of the wrapped `RNNCell`.
- `attention`: The attention emitted at the previous time step.
+ - `time`: int32 scalar containing the current time step.
+ - `attention_history`: (if enabled) a `TensorArray` containing attention
+ matrices from all time steps. Call `stack()` to convert to a `Tensor`.
"""
pass
@@ -418,6 +424,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
cell,
attention_mechanism,
attention_size,
+ attention_history=False,
cell_input_fn=None,
probability_fn=None,
output_attention=True,
@@ -429,6 +436,9 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
attention_mechanism: An instance of `AttentionMechanism`.
attention_size: Python integer, the depth of the attention (output)
tensor.
+ attention_history: Python boolean, whether to store attention history
+ from all time steps in the final output state (currently stored as a
+ time major `TensorArray` on which you must call `stack()`).
cell_input_fn: (optional) A `callable`. The default is:
`lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
probability_fn: (optional) A `callable`. Converts the score to
@@ -474,6 +484,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
self._cell_input_fn = cell_input_fn
self._probability_fn = probability_fn
self._output_attention = output_attention
+ self._attention_history = attention_history
@property
def output_size(self):
@@ -486,14 +497,23 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
def state_size(self):
return AttentionWrapperState(
cell_state=self._cell.state_size,
- attention=self._attention_size)
+ time=tensor_shape.TensorShape([]),
+ attention=self._attention_size,
+ attention_history=()) # attention_history is sometimes a TensorArray
def zero_state(self, batch_size, dtype):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+ if self._attention_history:
+ attention_history = tensor_array_ops.TensorArray(
+ dtype=dtype, size=0, dynamic_size=True)
+ else:
+ attention_history = ()
return AttentionWrapperState(
cell_state=self._cell.zero_state(batch_size, dtype),
+ time=array_ops.zeros([], dtype=dtypes.int32),
attention=_zero_state_tensors(
- self._attention_size, batch_size, dtype))
+ self._attention_size, batch_size, dtype),
+ attention_history=attention_history)
def __call__(self, inputs, state, scope=None):
"""Perform a step of attention-wrapped RNN.
@@ -555,9 +575,17 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
+ if self._attention_history:
+ attention_history = state.attention_history.write(
+ state.time, attention)
+ else:
+ attention_history = ()
+
next_state = AttentionWrapperState(
+ time=state.time + 1,
cell_state=next_cell_state,
- attention=attention)
+ attention=attention,
+ attention_history=attention_history)
if self._output_attention:
return attention, next_state
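
The bookkeeping added in zero_state and __call__ is the standard grow-and-stack TensorArray pattern: start empty with dynamic_size=True, write at index state.time on each step, and call stack() once decoding finishes. A stripped-down sketch of just that pattern, assuming TF 1.x graph/session execution and arbitrary shapes:

import tensorflow as tf
from tensorflow.python.ops import tensor_array_ops

batch_size, attention_depth, num_steps = 5, 6, 3

# zero_state: an empty, growable, time-major TensorArray plus a time counter.
history = tensor_array_ops.TensorArray(
    dtype=tf.float32, size=0, dynamic_size=True)
time = tf.zeros([], dtype=tf.int32)

# __call__, repeated num_steps times: write the current attention at
# index `time`, then advance the clock.
for _ in range(num_steps):
    attention = tf.random_normal([batch_size, attention_depth])
    history = history.write(time, attention)
    time = time + 1

# After decoding, stack() turns the history into a
# [num_steps, batch_size, attention_depth] Tensor.
stacked = history.stack()

with tf.Session() as sess:
    print(sess.run(stacked).shape)  # (3, 5, 6)
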