Diffstat (limited to 'tensorflow/contrib/seq2seq/python')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  112
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py  38
2 files changed, 139 insertions, 11 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 0232103c41..cd162bae25 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -110,7 +111,12 @@ class AttentionWrapperTest(test.TestCase):
alignment_history=False,
expected_final_alignment_history=None,
attention_layer_size=6,
+ attention_layer=None,
name=''):
+ attention_layer_sizes = (
+ [attention_layer_size] if attention_layer_size is not None else None)
+ attention_layers = (
+ [attention_layer] if attention_layer is not None else None)
self._testWithMaybeMultiAttention(
is_multi=False,
create_attention_mechanisms=[create_attention_mechanism],
@@ -119,7 +125,8 @@ class AttentionWrapperTest(test.TestCase):
attention_mechanism_depths=[attention_mechanism_depth],
alignment_history=alignment_history,
expected_final_alignment_history=expected_final_alignment_history,
- attention_layer_sizes=[attention_layer_size],
+ attention_layer_sizes=attention_layer_sizes,
+ attention_layers=attention_layers,
name=name)
def _testWithMaybeMultiAttention(self,
@@ -131,6 +138,7 @@ class AttentionWrapperTest(test.TestCase):
alignment_history=False,
expected_final_alignment_history=None,
attention_layer_sizes=None,
+ attention_layers=None,
name=''):
# Allow is_multi to be True with a single mechanism to enable test for
# passing in a single mechanism in a list.
@@ -144,12 +152,18 @@ class AttentionWrapperTest(test.TestCase):
encoder_output_depth = 10
cell_depth = 9
- if attention_layer_sizes is None:
- attention_depth = encoder_output_depth * len(create_attention_mechanisms)
- else:
+ if attention_layer_sizes is not None:
# Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
attention_depth = sum([attention_layer_size or encoder_output_depth
for attention_layer_size in attention_layer_sizes])
+ elif attention_layers is not None:
+ # Compute sum of attention_layers output depth.
+ attention_depth = sum(
+ attention_layer.compute_output_shape(
+ [batch_size, cell_depth + encoder_output_depth])[-1].value
+ for attention_layer in attention_layers)
+ else:
+ attention_depth = encoder_output_depth * len(create_attention_mechanisms)
decoder_inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time,
@@ -171,13 +185,20 @@ class AttentionWrapperTest(test.TestCase):
with vs.variable_scope(
'root',
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
+ attention_layer_size = attention_layer_sizes
+ attention_layer = attention_layers
+ if not is_multi:
+ if attention_layer_size is not None:
+ attention_layer_size = attention_layer_size[0]
+ if attention_layer is not None:
+ attention_layer = attention_layer[0]
cell = rnn_cell.LSTMCell(cell_depth)
cell = wrapper.AttentionWrapper(
cell,
attention_mechanisms if is_multi else attention_mechanisms[0],
- attention_layer_size=(attention_layer_sizes if is_multi
- else attention_layer_sizes[0]),
- alignment_history=alignment_history)
+ attention_layer_size=attention_layer_size,
+ alignment_history=alignment_history,
+ attention_layer=attention_layer)
helper = helper_py.TrainingHelper(decoder_inputs,
decoder_sequence_length)
my_decoder = basic_decoder.BasicDecoder(
@@ -260,6 +281,41 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history,
final_alignment_history_info)
+ def testBahdanauNormalizedDType(self):
+ for dtype in [np.float16, np.float32, np.float64]:
+ num_units = 128
+ encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
+ encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+ decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
+ decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+ batch_size = 64
+ attention_mechanism = wrapper.BahdanauAttention(
+ num_units=num_units,
+ memory=encoder_outputs,
+ memory_sequence_length=encoder_sequence_length,
+ normalize=True,
+ dtype=dtype,
+ )
+ cell = rnn_cell.LSTMCell(num_units)
+ cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+
+ helper = helper_py.TrainingHelper(decoder_inputs,
+ decoder_sequence_length)
+ my_decoder = basic_decoder.BasicDecoder(
+ cell=cell,
+ helper=helper,
+ initial_state=cell.zero_state(
+ dtype=dtype, batch_size=batch_size))
+
+ final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
+ self.assertTrue(
+ isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
+ self.assertEqual(final_outputs.rnn_output.dtype, dtype)
+ self.assertTrue(
+ isinstance(final_state, wrapper.AttentionWrapperState))
+ self.assertTrue(
+ isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
+
def testBahdanauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttention
@@ -797,6 +853,48 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=expected_final_alignment_history,
name='testMultiAttention')
+ def testMultiAttentionWithLayerInstances(self):
+ create_attention_mechanisms = (
+ wrapper.BahdanauAttention, wrapper.LuongAttention)
+
+ expected_final_output = BasicDecoderOutput(
+ rnn_output=ResultSummary(
+ shape=(5, 3, 7), dtype=dtype('float32'), mean=0.0011709079),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=3.2000000000000002))
+ expected_final_state = AttentionWrapperState(
+ cell_state=LSTMStateTuple(
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0038725811),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0019329828)),
+ attention=ResultSummary(
+ shape=(5, 7), dtype=dtype('float32'), mean=0.001174294),
+ time=3,
+ alignments=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ attention_state=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignment_history=())
+
+ expected_final_alignment_history = (
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
+
+ self._testWithMaybeMultiAttention(
+ True,
+ create_attention_mechanisms,
+ expected_final_output,
+ expected_final_state,
+ attention_mechanism_depths=[9, 9],
+ attention_layers=[layers_core.Dense(3, use_bias=False),
+ layers_core.Dense(4, use_bias=False)],
+ alignment_history=True,
+ expected_final_alignment_history=expected_final_alignment_history,
+ name='testMultiAttention')
+
def testLuongMonotonicHard(self):
# Run attention mechanism with mode='hard', make sure probabilities are hard
b, t, u, d = 10, 20, 30, 40
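
The test hunks above exercise the new `attention_layer` argument that the attention_wrapper.py change below introduces. A minimal standalone sketch of the same pattern, assuming the public tf.contrib.seq2seq aliases for these classes and illustrative shapes and layer sizes (not taken from the patch):

import tensorflow as tf
from tensorflow.python.layers import core as layers_core

batch_size, max_time, memory_depth, cell_depth = 5, 8, 6, 9
memory = tf.placeholder(tf.float32, [batch_size, max_time, memory_depth])

# One mechanism per attention "head"; num_units=9 matches the cell output
# depth so LuongAttention can dot the query against the projected memory.
mechanisms = [
    tf.contrib.seq2seq.BahdanauAttention(num_units=9, memory=memory),
    tf.contrib.seq2seq.LuongAttention(num_units=9, memory=memory),
]

# One layer per mechanism; their output depths (3 + 4 = 7) take the place
# of the depths that attention_layer_size would otherwise have supplied.
attention_layers = [layers_core.Dense(3, use_bias=False),
                    layers_core.Dense(4, use_bias=False)]

cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell,
    mechanisms,
    attention_layer=attention_layers,  # new keyword from this change
    alignment_history=True)
# With output_attention=True (the default), attn_cell.output_size is 7.
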
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 8a40a7ab53..1c9d179e3c 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -472,7 +472,8 @@ def _bahdanau_score(processed_query, keys, normalize):
# Scalar used in weight normalization
g = variable_scope.get_variable(
"attention_g", dtype=dtype,
- initializer=math.sqrt((1. / num_units)))
+ initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))),
+ shape=())
# Bias added prior to the nonlinearity
b = variable_scope.get_variable(
"attention_b", [num_units], dtype=dtype,
@@ -1082,7 +1083,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
cell_input_fn=None,
output_attention=True,
initial_cell_state=None,
- name=None):
+ name=None,
+ attention_layer=None):
"""Construct the `AttentionWrapper`.
**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
@@ -1125,7 +1127,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
(default), use the context as attention at each time step. Otherwise,
feed the context and cell output into the attention layer to generate
attention at each time step. If attention_mechanism is a list,
- attention_layer_size must be a list of the same length.
+ attention_layer_size must be a list of the same length. If
+ attention_layer is set, this must be None.
alignment_history: Python boolean, whether to store alignment history
from all time steps in the final output state (currently stored as a
time major `TensorArray` on which you must call `stack()`).
@@ -1145,12 +1148,19 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
does not match the batch size of `initial_cell_state`, proper
behavior is not guaranteed.
name: Name to use when creating ops.
+ attention_layer: A list of `tf.layers.Layer` instances or a
+ single `tf.layers.Layer` instance taking the context and cell output as
+ inputs to generate attention at each time step. If None (default), use
+ the context as attention at each time step. If attention_mechanism is a
+ list, attention_layer must be a list of the same length. If
+ attention_layer_size is set, this must be None.
Raises:
TypeError: `attention_layer_size` is not None and (`attention_mechanism`
is a list but `attention_layer_size` is not; or vice versa).
ValueError: if `attention_layer_size` is not None, `attention_mechanism`
- is a list, and its length does not match that of `attention_layer_size`.
+ is a list, and its length does not match that of `attention_layer_size`;
+ if `attention_layer_size` and `attention_layer` are set simultaneously.
"""
super(AttentionWrapper, self).__init__(name=name)
rnn_cell_impl.assert_like_rnncell("cell", cell)
@@ -1181,6 +1191,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
"cell_input_fn must be callable, saw type: %s"
% type(cell_input_fn).__name__)
+ if attention_layer_size is not None and attention_layer is not None:
+ raise ValueError("Only one of attention_layer_size and attention_layer "
+ "should be set")
+
if attention_layer_size is not None:
attention_layer_sizes = tuple(
attention_layer_size
@@ -1199,6 +1213,22 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
dtype=attention_mechanisms[i].dtype)
for i, attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes)
+ elif attention_layer is not None:
+ self._attention_layers = tuple(
+ attention_layer
+ if isinstance(attention_layer, (list, tuple))
+ else (attention_layer,))
+ if len(self._attention_layers) != len(attention_mechanisms):
+ raise ValueError(
+ "If provided, attention_layer must contain exactly one "
+ "layer per attention_mechanism, saw: %d vs %d"
+ % (len(self._attention_layers), len(attention_mechanisms)))
+ self._attention_layer_size = sum(
+ layer.compute_output_shape(
+ [None,
+ cell.output_size + mechanism.values.shape[-1].value])[-1].value
+ for layer, mechanism in zip(
+ self._attention_layers, attention_mechanisms))
else:
self._attention_layers = None
self._attention_layer_size = sum(
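
The new elif branch above derives the wrapper's attention size from each layer's compute_output_shape on a [batch, cell_output + memory_depth] input, mirroring what the updated test does. A condensed sketch of just that computation, with illustrative sizes (cell output 9, two memories of depth 6) rather than values from the patch:

from tensorflow.python.layers import core as layers_core

cell_output_size = 9
memory_depths = [6, 6]  # mechanism.values.shape[-1].value per mechanism
layers = [layers_core.Dense(3, use_bias=False),
          layers_core.Dense(4, use_bias=False)]

# Each layer consumes the concatenated [cell output; attention context],
# so its compute_output_shape gives that mechanism's attention depth.
attention_layer_size = sum(
    layer.compute_output_shape(
        [None, cell_output_size + depth])[-1].value
    for layer, depth in zip(layers, memory_depths))
# attention_layer_size == 7, the sum of the Dense units above.
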