author     Yong Tang <yong.tang.github@outlook.com>        2018-03-30 10:41:48 -0700
committer  Rasmus Munk Larsen <rmlarsen@google.com>        2018-03-30 10:41:48 -0700
commit     65669c0e12ba7aaa43a34db3798b8ac8fd97ba3e (patch)
tree       d6c9e1a2993ee96b15a32a223149022ef4abb74b /tensorflow/contrib/seq2seq
parent     8328d84fb54c59b93161950e50709b401576bcc3 (diff)
Fix issue with Luong attention when scale=True and dtype of tf.float16/tf.float64 (#18106)
* Fix issue with Luong attention when scale=True and dtype=tf.float16/tf.float64

  This addresses the issue raised in 18099, where Luong attention throws a
  ValueError when scale=True and the dtype is not tf.float32. A test case is
  added to cover the fix. This closes 18099.

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix pylint issue

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test case for Luong attention with scale=True and dtype=float16/float64

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add assertEqual to confirm the dtypes of the output

  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
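The underlying problem is that `variable_scope.get_variable` infers float32 from the Python float literal `1.` used as the initializer, which conflicts with an explicitly requested float16/float64 dtype. A minimal sketch of the before/after pattern, assuming the TF 1.x `tf.get_variable` API (the scope name `luong_repro` is illustrative):

```python
import tensorflow as tf

dtype = tf.float16  # any non-float32 dtype triggers the original error

with tf.variable_scope("luong_repro"):
    # Before the fix: the Python float literal is inferred as float32, which
    # conflicts with the requested dtype and raises a ValueError.
    # g = tf.get_variable("attention_g", dtype=dtype, initializer=1.)

    # After the fix: a dtype-agnostic initializer plus an explicit scalar
    # shape lets the variable follow whatever dtype the caller requests.
    g = tf.get_variable("attention_g", dtype=dtype,
                        initializer=tf.ones_initializer(), shape=())
```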
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 36
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py               |  3
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 07b3ad71d4..d508cf3f9d 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -353,6 +353,42 @@ class AttentionWrapperTest(test.TestCase):
         attention_mechanism_depth=9,
         name='testLuongNotNormalized')

+  def testLuongScaledDType(self):
+    # Test case for GitHub issue 18099
+    for dtype in [np.float16, np.float32, np.float64]:
+      num_units = 128
+      encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
+      encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+      decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
+      decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
+      batch_size = 64
+      attention_mechanism = wrapper.LuongAttention(
+          num_units=num_units,
+          memory=encoder_outputs,
+          memory_sequence_length=encoder_sequence_length,
+          scale=True,
+          dtype=dtype,
+      )
+      cell = rnn_cell.LSTMCell(num_units)
+      cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+
+      helper = helper_py.TrainingHelper(decoder_inputs,
+                                        decoder_sequence_length)
+      my_decoder = basic_decoder.BasicDecoder(
+          cell=cell,
+          helper=helper,
+          initial_state=cell.zero_state(
+              dtype=dtype, batch_size=batch_size))
+
+      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
+      self.assertTrue(
+          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
+      self.assertEqual(final_outputs.rnn_output.dtype, dtype)
+      self.assertTrue(
+          isinstance(final_state, wrapper.AttentionWrapperState))
+      self.assertTrue(
+          isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
+
   def testLuongScaled(self):
     create_attention_mechanism = functools.partial(
         wrapper.LuongAttention, scale=True)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index be53779826..9e0d69593f 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -339,7 +339,8 @@ def _luong_score(query, keys, scale):
   if scale:
     # Scalar used in weight scaling
     g = variable_scope.get_variable(
-        "attention_g", dtype=dtype, initializer=1.)
+        "attention_g", dtype=dtype,
+        initializer=init_ops.ones_initializer, shape=())
     score = g * score
   return score
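With the initializer change in place, scaled Luong attention can be constructed directly in a non-default dtype. A minimal usage sketch, assuming the TF 1.x `tf.contrib.seq2seq` API (the batch size, depths, and placeholder shapes below are illustrative):

```python
import tensorflow as tf

dtype = tf.float64
# Illustrative encoder outputs: [batch, time, depth], plus per-example lengths.
memory = tf.placeholder(dtype, shape=[32, None, 256])
memory_sequence_length = tf.placeholder(tf.int32, shape=[32])

# scale=True now creates the "attention_g" scalar in the requested dtype
# instead of failing on a float32-only initializer.
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=128,
    memory=memory,
    memory_sequence_length=memory_sequence_length,
    scale=True,
    dtype=dtype)

cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.LSTMCell(128), attention_mechanism)
```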