diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2018-04-01 02:07:30 +0000 |
---|---|---|
committer | Yong Tang <yong.tang.github@outlook.com> | 2018-04-16 20:04:11 +0000 |
commit | fe4ab63ab258d67f37844f374db265130ceecf2a (patch) | |
tree | fef3501e8341025ddf04f246a3ba429600b5ba78 /tensorflow/contrib/seq2seq | |
parent | 35b8a8cfebe910687f3cc038c00a6e33ba09637a (diff) |
Add test case for Bahdanau attention when normalized=True and dtype = float16/32
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 84a7b45b5a..6781433a1f 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -281,6 +281,41 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history, final_alignment_history_info) + def testBahdanauNormalizedDType(self): + for dtype in [np.float16, np.float32, np.float64]: + num_units = 128 + encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256]) + encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64]) + decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128]) + decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64]) + batch_size = 64 + attention_mechanism = wrapper.BahdanauAttention( + num_units=num_units, + memory=encoder_outputs, + memory_sequence_length=encoder_sequence_length, + normalize=True, + dtype=dtype, + ) + cell = rnn_cell.LSTMCell(num_units) + cell = wrapper.AttentionWrapper(cell, attention_mechanism) + + helper = helper_py.TrainingHelper(decoder_inputs, + decoder_sequence_length) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, + helper=helper, + initial_state=cell.zero_state( + dtype=dtype, batch_size=batch_size)) + + final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder) + self.assertTrue( + isinstance(final_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual(final_outputs.rnn_output.dtype, dtype) + self.assertTrue( + isinstance(final_state, wrapper.AttentionWrapperState)) + self.assertTrue( + isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple)) + def testBahdanauNotNormalized(self): create_attention_mechanism = wrapper.BahdanauAttention |