Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  112
1 file changed, 105 insertions(+), 7 deletions(-)
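Note: the hunks below extend the wrapper tests so that pre-built attention layer instances can be passed through the new attention_layer argument instead of attention_layer_size. As a minimal, non-authoritative sketch of the pattern these tests exercise (not part of this diff; the placeholder shapes and layer sizes are illustrative assumptions), the usage looks roughly like:

import numpy as np

from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn_cell

# Illustrative shapes (assumed, not taken from the diff).
batch_size, max_time, encoder_output_depth = 5, 8, 10
encoder_outputs = array_ops.placeholder_with_default(
    np.random.randn(batch_size, max_time,
                    encoder_output_depth).astype(np.float32),
    shape=(None, None, encoder_output_depth))

attention_mechanism = wrapper.BahdanauAttention(
    num_units=9, memory=encoder_outputs)

cell = rnn_cell.LSTMCell(9)
# With a layer instance, the attention output depth comes from the Dense
# layer's output size (3 here) rather than from attention_layer_size.
cell = wrapper.AttentionWrapper(
    cell,
    attention_mechanism,
    attention_layer=layers_core.Dense(3, use_bias=False))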
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 0232103c41..cd162bae25 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -110,7 +111,12 @@ class AttentionWrapperTest(test.TestCase):
alignment_history=False,
expected_final_alignment_history=None,
attention_layer_size=6,
+ attention_layer=None,
name=''):
+ attention_layer_sizes = (
+ [attention_layer_size] if attention_layer_size is not None else None)
+ attention_layers = (
+ [attention_layer] if attention_layer is not None else None)
self._testWithMaybeMultiAttention(
is_multi=False,
create_attention_mechanisms=[create_attention_mechanism],
@@ -119,7 +125,8 @@ class AttentionWrapperTest(test.TestCase):
attention_mechanism_depths=[attention_mechanism_depth],
alignment_history=alignment_history,
expected_final_alignment_history=expected_final_alignment_history,
- attention_layer_sizes=[attention_layer_size],
+ attention_layer_sizes=attention_layer_sizes,
+ attention_layers=attention_layers,
name=name)

def _testWithMaybeMultiAttention(self,
@@ -131,6 +138,7 @@ class AttentionWrapperTest(test.TestCase):
alignment_history=False,
expected_final_alignment_history=None,
attention_layer_sizes=None,
+ attention_layers=None,
name=''):
# Allow is_multi to be True with a single mechanism to enable test for
# passing in a single mechanism in a list.
@@ -144,12 +152,18 @@ class AttentionWrapperTest(test.TestCase):
encoder_output_depth = 10
cell_depth = 9

- if attention_layer_sizes is None:
- attention_depth = encoder_output_depth * len(create_attention_mechanisms)
- else:
+ if attention_layer_sizes is not None:
# Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
attention_depth = sum([attention_layer_size or encoder_output_depth
for attention_layer_size in attention_layer_sizes])
+ elif attention_layers is not None:
+ # Compute sum of attention_layers output depth.
+ attention_depth = sum(
+ attention_layer.compute_output_shape(
+ [batch_size, cell_depth + encoder_output_depth])[-1].value
+ for attention_layer in attention_layers)
+ else:
+ attention_depth = encoder_output_depth * len(create_attention_mechanisms)

decoder_inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time,
@@ -171,13 +185,20 @@ class AttentionWrapperTest(test.TestCase):
with vs.variable_scope(
'root',
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
+ attention_layer_size = attention_layer_sizes
+ attention_layer = attention_layers
+ if not is_multi:
+ if attention_layer_size is not None:
+ attention_layer_size = attention_layer_size[0]
+ if attention_layer is not None:
+ attention_layer = attention_layer[0]
cell = rnn_cell.LSTMCell(cell_depth)
cell = wrapper.AttentionWrapper(
cell,
attention_mechanisms if is_multi else attention_mechanisms[0],
- attention_layer_size=(attention_layer_sizes if is_multi
- else attention_layer_sizes[0]),
- alignment_history=alignment_history)
+ attention_layer_size=attention_layer_size,
+ alignment_history=alignment_history,
+ attention_layer=attention_layer)
helper = helper_py.TrainingHelper(decoder_inputs,
decoder_sequence_length)
my_decoder = basic_decoder.BasicDecoder(
@@ -260,6 +281,41 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history,
final_alignment_history_info)

+ def testBahdanauNormalizedDType(self):
+ for dtype in [np.float16, np.float32, np.float64]:
+ num_units = 128
+ encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
+ encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+ decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
+ decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+ batch_size = 64
+ attention_mechanism = wrapper.BahdanauAttention(
+ num_units=num_units,
+ memory=encoder_outputs,
+ memory_sequence_length=encoder_sequence_length,
+ normalize=True,
+ dtype=dtype,
+ )
+ cell = rnn_cell.LSTMCell(num_units)
+ cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+
+ helper = helper_py.TrainingHelper(decoder_inputs,
+ decoder_sequence_length)
+ my_decoder = basic_decoder.BasicDecoder(
+ cell=cell,
+ helper=helper,
+ initial_state=cell.zero_state(
+ dtype=dtype, batch_size=batch_size))
+
+ final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
+ self.assertTrue(
+ isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
+ self.assertEqual(final_outputs.rnn_output.dtype, dtype)
+ self.assertTrue(
+ isinstance(final_state, wrapper.AttentionWrapperState))
+ self.assertTrue(
+ isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
+
def testBahdanauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttention
@@ -797,6 +853,48 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=expected_final_alignment_history,
name='testMultiAttention')

+ def testMultiAttentionWithLayerInstances(self):
+ create_attention_mechanisms = (
+ wrapper.BahdanauAttention, wrapper.LuongAttention)
+
+ expected_final_output = BasicDecoderOutput(
+ rnn_output=ResultSummary(
+ shape=(5, 3, 7), dtype=dtype('float32'), mean=0.0011709079),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=3.2000000000000002))
+ expected_final_state = AttentionWrapperState(
+ cell_state=LSTMStateTuple(
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0038725811),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0019329828)),
+ attention=ResultSummary(
+ shape=(5, 7), dtype=dtype('float32'), mean=0.001174294),
+ time=3,
+ alignments=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ attention_state=(
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignment_history=())
+
+ expected_final_alignment_history = (
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
+
+ self._testWithMaybeMultiAttention(
+ True,
+ create_attention_mechanisms,
+ expected_final_output,
+ expected_final_state,
+ attention_mechanism_depths=[9, 9],
+ attention_layers=[layers_core.Dense(3, use_bias=False),
+ layers_core.Dense(4, use_bias=False)],
+ alignment_history=True,
+ expected_final_alignment_history=expected_final_alignment_history,
+ name='testMultiAttention')
+
def testLuongMonotonicHard(self):
# Run attention mechanism with mode='hard', make sure probabilities are hard
b, t, u, d = 10, 20, 30, 40