author    Yong Tang <yong.tang.github@outlook.com>  2018-04-01 02:07:30 +0000
committer Yong Tang <yong.tang.github@outlook.com>  2018-04-16 20:04:11 +0000
commit    fe4ab63ab258d67f37844f374db265130ceecf2a (patch)
tree      fef3501e8341025ddf04f246a3ba429600b5ba78 /tensorflow/contrib/seq2seq
parent    35b8a8cfebe910687f3cc038c00a6e33ba09637a (diff)
Add test case for Bahdanau attention when normalize=True and dtype = float16/32/64
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
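
Not part of the patch, but for context: below is a minimal standalone sketch
of the same check, written against the public tf.contrib.seq2seq namespace
(TF 1.x) rather than the test's internal imports, in case you want to
reproduce it outside the bazel test harness. The batch size, unit counts,
and shapes here are illustrative assumptions, not values mandated by the fix.

    import numpy as np
    import tensorflow as tf

    for dtype in [np.float16, np.float32, np.float64]:
        tf.reset_default_graph()
        # Fake encoder memory: [batch, time, depth], in the dtype under test.
        memory = tf.placeholder(dtype, shape=[8, None, 32])
        memory_length = tf.placeholder(tf.int32, shape=[8])
        attention = tf.contrib.seq2seq.BahdanauAttention(
            num_units=16,
            memory=memory,
            memory_sequence_length=memory_length,
            normalize=True,  # the weight-normalized scoring path under test
            dtype=dtype)
        cell = tf.contrib.seq2seq.AttentionWrapper(
            tf.nn.rnn_cell.LSTMCell(16), attention)
        helper = tf.contrib.seq2seq.TrainingHelper(
            tf.placeholder(dtype, shape=[8, None, 16]),  # decoder inputs
            tf.placeholder(tf.int32, shape=[8]))         # decoder lengths
        my_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell, helper, cell.zero_state(batch_size=8, dtype=dtype))
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder)
        # The requested dtype should be threaded all the way through the
        # normalized attention scoring to the decoder output.
        assert outputs.rnn_output.dtype == tf.as_dtype(dtype)

The point under test is that the dtype argument is threaded through both the
attention mechanism's memory layer and its normalization variables, so that
graph construction succeeds and rnn_output carries the requested dtype rather
than silently defaulting to float32.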
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  35
1 file changed, 35 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 84a7b45b5a..6781433a1f 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -281,6 +281,41 @@ class AttentionWrapperTest(test.TestCase):
           expected_final_alignment_history,
           final_alignment_history_info)
+  def testBahdanauNormalizedDType(self):
+    for dtype in [np.float16, np.float32, np.float64]:
+      num_units = 128
+      encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
+      encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+      decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
+      decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+      batch_size = 64
+      attention_mechanism = wrapper.BahdanauAttention(
+          num_units=num_units,
+          memory=encoder_outputs,
+          memory_sequence_length=encoder_sequence_length,
+          normalize=True,
+          dtype=dtype,
+      )
+      cell = rnn_cell.LSTMCell(num_units)
+      cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+
+      helper = helper_py.TrainingHelper(decoder_inputs,
+                                        decoder_sequence_length)
+      my_decoder = basic_decoder.BasicDecoder(
+          cell=cell,
+          helper=helper,
+          initial_state=cell.zero_state(
+              dtype=dtype, batch_size=batch_size))
+
+      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
+      self.assertTrue(
+          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
+      self.assertEqual(final_outputs.rnn_output.dtype, dtype)
+      self.assertTrue(
+          isinstance(final_state, wrapper.AttentionWrapperState))
+      self.assertTrue(
+          isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
+
   def testBahdanauNotNormalized(self):
     create_attention_mechanism = wrapper.BahdanauAttention