author     Yong Tang <yong.tang.github@outlook.com>        2018-03-30 10:41:48 -0700
committer  Rasmus Munk Larsen <rmlarsen@google.com>        2018-03-30 10:41:48 -0700
commit     65669c0e12ba7aaa43a34db3798b8ac8fd97ba3e (patch)
tree       d6c9e1a2993ee96b15a32a223149022ef4abb74b /tensorflow/contrib/seq2seq
parent     8328d84fb54c59b93161950e50709b401576bcc3 (diff)
Fix issue with Luong attention when scale=True and dtype of tf.float16/tf.float64 (#18106)
* Fix issue with Luong attention when scale=True and dtype=tf.float16/tf.float64
This fix addresses the issue raised in 18099, where Luong attention throws a
ValueError when scale=True and the dtype is not tf.float32. The scaling
variable is now created with ones_initializer and an explicit scalar shape
instead of a plain float32 literal, and a test case covering float16 and
float64 is added (a minimal reproduction sketch follows the commit message below).
This fix fixes 18099.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Fix pylint issue
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test case for Luong attention with scale=True and dtype=float16/float64
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add assertEqual to confirm the dtypes of the output
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
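For context (editor's sketch, not part of the commit): the ValueError comes from creating the scale variable from a plain Python float, which TF 1.x infers as float32 and then rejects when an explicit float16/float64 dtype is requested. The sketch below uses the public tf.get_variable API rather than the internal variable_scope module used in the patch.

import tensorflow as tf  # TF 1.x graph-mode API assumed

# Pre-fix pattern: a Python float literal is inferred as float32, which
# conflicts with an explicit float16/float64 dtype and raises a ValueError.
# g = tf.get_variable("attention_g", dtype=tf.float16, initializer=1.)

# Post-fix pattern: an initializer object plus an explicit scalar shape lets
# the variable be created directly in the requested dtype.
g = tf.get_variable(
    "attention_g", dtype=tf.float16,
    initializer=tf.ones_initializer(), shape=())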
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 36
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py               |  3
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 07b3ad71d4..d508cf3f9d 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -353,6 +353,42 @@ class AttentionWrapperTest(test.TestCase):
         attention_mechanism_depth=9,
         name='testLuongNotNormalized')
 
+  def testLuongScaledDType(self):
+    # Test case for GitHub issue 18099
+    for dtype in [np.float16, np.float32, np.float64]:
+      num_units = 128
+      encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
+      encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+      decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
+      decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+      batch_size = 64
+      attention_mechanism = wrapper.LuongAttention(
+          num_units=num_units,
+          memory=encoder_outputs,
+          memory_sequence_length=encoder_sequence_length,
+          scale=True,
+          dtype=dtype,
+      )
+      cell = rnn_cell.LSTMCell(num_units)
+      cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+
+      helper = helper_py.TrainingHelper(decoder_inputs,
+                                        decoder_sequence_length)
+      my_decoder = basic_decoder.BasicDecoder(
+          cell=cell,
+          helper=helper,
+          initial_state=cell.zero_state(
+              dtype=dtype, batch_size=batch_size))
+
+      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
+      self.assertTrue(
+          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
+      self.assertEqual(final_outputs.rnn_output.dtype, dtype)
+      self.assertTrue(
+          isinstance(final_state, wrapper.AttentionWrapperState))
+      self.assertTrue(
+          isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
+
   def testLuongScaled(self):
     create_attention_mechanism = functools.partial(
         wrapper.LuongAttention, scale=True)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index be53779826..9e0d69593f 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -339,7 +339,8 @@ def _luong_score(query, keys, scale):
   if scale:
     # Scalar used in weight scaling
     g = variable_scope.get_variable(
-        "attention_g", dtype=dtype, initializer=1.)
+        "attention_g", dtype=dtype,
+        initializer=init_ops.ones_initializer, shape=())
     score = g * score
   return score
 
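For illustration only (editor's sketch, not part of the commit): a hedged usage example of the user-facing contrib API after this change, using TF 1.x graph mode; the placeholder shapes are arbitrary.

import tensorflow as tf

# With this fix, scaled Luong attention can be constructed in float16 or
# float64; previously the scale variable forced float32 and raised an error.
dtype = tf.float64
memory = tf.placeholder(dtype, shape=[64, None, 256])
memory_sequence_length = tf.placeholder(tf.int32, shape=[64])

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=128,
    memory=memory,
    memory_sequence_length=memory_sequence_length,
    scale=True,
    dtype=dtype)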