author    Guillaume Klein <guillaumekln@users.noreply.github.com>  2018-04-16 04:21:29 +0200
committer Jonathan Hseu <vomjom@vomjom.net>  2018-04-15 19:21:29 -0700
commit    0586c57292a7bd1a79b4a03270c0f1c32d02a4af (patch)
tree      8f582fd4dc9b67753d87b815f0193a1e6e3c2995 /tensorflow/contrib/seq2seq
parent    54772bb9a4a44badf4a70d75f41426c51f47cf3e (diff)
Support passing layer instances to produce attentional hidden states (#14974)
* Support passing Layer instances to the AttentionWrapper.
* Use _compute_output_shape to get the attention layer depth.
* compute_output_shape is now a public method.
* Move the new argument to the end.
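
As a usage illustration (not part of the patch itself), the new keyword lets a caller hand the wrapper an arbitrary projection layer instead of an integer size. A minimal sketch against the contrib API as of this commit, with made-up shapes and hyperparameters:

    import tensorflow as tf

    # Hypothetical encoder outputs: [batch_size, max_time, encoder_output_depth].
    memory = tf.placeholder(tf.float32, [None, None, 10])

    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=9, memory=memory)

    # Before this change only an integer attention_layer_size could be given,
    # and the wrapper built a Dense(attention_layer_size, use_bias=False)
    # internally. Now any tf.layers.Layer that maps the concatenated
    # [cell_output; context] vector to the attentional hidden state can be
    # passed instead.
    attention_layer = tf.layers.Dense(6, use_bias=False, activation=tf.tanh)

    cell = tf.nn.rnn_cell.LSTMCell(9)
    cell = tf.contrib.seq2seq.AttentionWrapper(
        cell,
        attention_mechanism,
        attention_layer=attention_layer)  # mutually exclusive with attention_layer_size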
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  77
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py  35
2 files changed, 102 insertions, 10 deletions
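
The depth bookkeeping in both files relies on the now-public compute_output_shape method: each supplied layer sees the concatenated cell output and attention context, and its output depth is read back from the inferred shape. Roughly, under the names used in the patch:

    # Sketch of the depth computation the patch performs (names as in the diff).
    attention_depth = sum(
        layer.compute_output_shape(
            [None, cell.output_size + mechanism.values.shape[-1].value])[-1].value
        for layer, mechanism in zip(attention_layers, attention_mechanisms))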
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index d508cf3f9d..84a7b45b5a 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -110,7 +111,12 @@ class AttentionWrapperTest(test.TestCase):
alignment_history=False,
expected_final_alignment_history=None,
attention_layer_size=6,
+ attention_layer=None,
name=''):
+ attention_layer_sizes = (
+ [attention_layer_size] if attention_layer_size is not None else None)
+ attention_layers = (
+ [attention_layer] if attention_layer is not None else None)
self._testWithMaybeMultiAttention(
is_multi=False,
create_attention_mechanisms=[create_attention_mechanism],
@@ -119,7 +125,8 @@ class AttentionWrapperTest(test.TestCase):
attention_mechanism_depths=[attention_mechanism_depth],
alignment_history=alignment_history,
expected_final_alignment_history=expected_final_alignment_history,
- attention_layer_sizes=[attention_layer_size],
+ attention_layer_sizes=attention_layer_sizes,
+ attention_layers=attention_layers,
name=name)
def _testWithMaybeMultiAttention(self,
@@ -131,6 +138,7 @@ class AttentionWrapperTest(test.TestCase):
alignment_history=False,
expected_final_alignment_history=None,
attention_layer_sizes=None,
+ attention_layers=None,
name=''):
# Allow is_multi to be True with a single mechanism to enable test for
# passing in a single mechanism in a list.
@@ -144,12 +152,18 @@ class AttentionWrapperTest(test.TestCase):
encoder_output_depth = 10
cell_depth = 9
- if attention_layer_sizes is None:
- attention_depth = encoder_output_depth * len(create_attention_mechanisms)
- else:
+ if attention_layer_sizes is not None:
# Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
attention_depth = sum([attention_layer_size or encoder_output_depth
for attention_layer_size in attention_layer_sizes])
+ elif attention_layers is not None:
+ # Compute sum of attention_layers output depth.
+ attention_depth = sum(
+ attention_layer.compute_output_shape(
+ [batch_size, cell_depth + encoder_output_depth])[-1].value
+ for attention_layer in attention_layers)
+ else:
+ attention_depth = encoder_output_depth * len(create_attention_mechanisms)
decoder_inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time,
@@ -171,13 +185,20 @@ class AttentionWrapperTest(test.TestCase):
with vs.variable_scope(
'root',
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
+ attention_layer_size = attention_layer_sizes
+ attention_layer = attention_layers
+ if not is_multi:
+ if attention_layer_size is not None:
+ attention_layer_size = attention_layer_size[0]
+ if attention_layer is not None:
+ attention_layer = attention_layer[0]
cell = rnn_cell.LSTMCell(cell_depth)
cell = wrapper.AttentionWrapper(
cell,
attention_mechanisms if is_multi else attention_mechanisms[0],
- attention_layer_size=(attention_layer_sizes if is_multi
- else attention_layer_sizes[0]),
- alignment_history=alignment_history)
+ attention_layer_size=attention_layer_size,
+ alignment_history=alignment_history,
+ attention_layer=attention_layer)
helper = helper_py.TrainingHelper(decoder_inputs,
decoder_sequence_length)
my_decoder = basic_decoder.BasicDecoder(
@@ -797,6 +818,48 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=expected_final_alignment_history,
name='testMultiAttention')
+ def testMultiAttentionWithLayerInstances(self):
+ create_attention_mechanisms = (
+ wrapper.BahdanauAttention, wrapper.LuongAttention)
+
+ expected_final_output = BasicDecoderOutput(
+ rnn_output=ResultSummary(
+ shape=(5, 3, 7), dtype=dtype('float32'), mean=0.0011709079),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=3.2000000000000002))
+ expected_final_state = AttentionWrapperState(
+ cell_state=LSTMStateTuple(
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0038725811),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0019329828)),
+ attention=ResultSummary(
+ shape=(5, 7), dtype=dtype('float32'), mean=0.001174294),
+ time=3,
+ alignments=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ attention_state=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignment_history=())
+
+ expected_final_alignment_history = (
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
+
+ self._testWithMaybeMultiAttention(
+ True,
+ create_attention_mechanisms,
+ expected_final_output,
+ expected_final_state,
+ attention_mechanism_depths=[9, 9],
+ attention_layers=[layers_core.Dense(3, use_bias=False),
+ layers_core.Dense(4, use_bias=False)],
+ alignment_history=True,
+ expected_final_alignment_history=expected_final_alignment_history,
+ name='testMultiAttention')
+
def testLuongMonotonicHard(self):
# Run attention mechanism with mode='hard', make sure probabilities are hard
b, t, u, d = 10, 20, 30, 40
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index f0f143ddfc..9ba541ce23 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -1082,7 +1082,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
cell_input_fn=None,
output_attention=True,
initial_cell_state=None,
- name=None):
+ name=None,
+ attention_layer=None):
"""Construct the `AttentionWrapper`.
**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
@@ -1125,7 +1126,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
(default), use the context as attention at each time step. Otherwise,
feed the context and cell output into the attention layer to generate
attention at each time step. If attention_mechanism is a list,
- attention_layer_size must be a list of the same length.
+ attention_layer_size must be a list of the same length. If
+ attention_layer is set, this must be None.
alignment_history: Python boolean, whether to store alignment history
from all time steps in the final output state (currently stored as a
time major `TensorArray` on which you must call `stack()`).
@@ -1145,12 +1147,19 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
does not match the batch size of `initial_cell_state`, proper
behavior is not guaranteed.
name: Name to use when creating ops.
+ attention_layer: A list of `tf.layers.Layer` instances or a
+ single `tf.layers.Layer` instance taking the context and cell output as
+ inputs to generate attention at each time step. If None (default), use
+ the context as attention at each time step. If attention_mechanism is a
+ list, attention_layer must be a list of the same length. If
+ attention_layers_size is set, this must be None.
Raises:
TypeError: `attention_layer_size` is not None and (`attention_mechanism`
is a list but `attention_layer_size` is not; or vice versa).
ValueError: if `attention_layer_size` is not None, `attention_mechanism`
- is a list, and its length does not match that of `attention_layer_size`.
+ is a list, and its length does not match that of `attention_layer_size`;
+ if `attention_layer_size` and `attention_layer` are set simultaneously.
"""
super(AttentionWrapper, self).__init__(name=name)
rnn_cell_impl.assert_like_rnncell("cell", cell)
@@ -1181,6 +1190,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
"cell_input_fn must be callable, saw type: %s"
% type(cell_input_fn).__name__)
+ if attention_layer_size is not None and attention_layer is not None:
+ raise ValueError("Only one of attention_layer_size and attention_layer "
+ "should be set")
+
if attention_layer_size is not None:
attention_layer_sizes = tuple(
attention_layer_size
@@ -1199,6 +1212,22 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
dtype=attention_mechanisms[i].dtype)
for i, attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes)
+ elif attention_layer is not None:
+ self._attention_layers = tuple(
+ attention_layer
+ if isinstance(attention_layer, (list, tuple))
+ else (attention_layer,))
+ if len(self._attention_layers) != len(attention_mechanisms):
+ raise ValueError(
+ "If provided, attention_layer must contain exactly one "
+ "layer per attention_mechanism, saw: %d vs %d"
+ % (len(self._attention_layers), len(attention_mechanisms)))
+ self._attention_layer_size = sum(
+ layer.compute_output_shape(
+ [None,
+ cell.output_size + mechanism.values.shape[-1].value])[-1].value
+ for layer, mechanism in zip(
+ self._attention_layers, attention_mechanisms))
else:
self._attention_layers = None
self._attention_layer_size = sum(