-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py | 37 ++++++++++++++++++
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py               |  6 +++++-
2 files changed, 42 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 91493302b1..01a5540121 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import variable_scope as vs
@@ -589,6 +590,24 @@ class AttentionWrapperTest(test.TestCase):
         expected_final_alignment_history=expected_final_alignment_history,
         name='testBahdanauMonotonicNormalized')
 
+  def testBahdanauMonotonicHard(self):
+    # Run attention mechanism with mode='hard', make sure probabilities are hard
+    b, t, u, d = 10, 20, 30, 40
+    with self.test_session(use_gpu=True) as sess:
+      a = wrapper.BahdanauMonotonicAttention(
+          d,
+          random_ops.random_normal((b, t, u)),
+          mode='hard')
+      # Just feed previous attention as [1, 0, 0, ...]
+      attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+      sess.run(variables.global_variables_initializer())
+      attn_out = attn.eval()
+      # All values should be 0 or 1
+      self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
+      # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
+      self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
+                                           attn_out.sum(axis=1) == 0)))
+
   def testLuongMonotonicNotNormalized(self):
     create_attention_mechanism = functools.partial(
         wrapper.LuongMonotonicAttention, sigmoid_noise=1.0,
@@ -695,6 +714,24 @@ class AttentionWrapperTest(test.TestCase):
         expected_final_alignment_history=expected_final_alignment_history,
         name='testMultiAttention')
 
+  def testLuongMonotonicHard(self):
+    # Run attention mechanism with mode='hard', make sure probabilities are hard
+    b, t, u, d = 10, 20, 30, 40
+    with self.test_session(use_gpu=True) as sess:
+      a = wrapper.LuongMonotonicAttention(
+          d,
+          random_ops.random_normal((b, t, u)),
+          mode='hard')
+      # Just feed previous attention as [1, 0, 0, ...]
+      attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+      sess.run(variables.global_variables_initializer())
+      attn_out = attn.eval()
+      # All values should be 0 or 1
+      self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
+      # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
+      self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
+                                           attn_out.sum(axis=1) == 0)))
+
   def testMultiAttentionNoAttentionLayer(self):
     create_attention_mechanisms = (
         wrapper.BahdanauAttention, wrapper.LuongAttention)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 839df079ee..87230e3355 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -679,7 +679,11 @@ def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
                                      seed=seed)
     score += sigmoid_noise*noise
   # Compute "choosing" probabilities from the attention scores
-  p_choose_i = math_ops.sigmoid(score)
+  if mode == "hard":
+    # When mode is hard, use a hard sigmoid
+    p_choose_i = math_ops.cast(score > 0, score.dtype)
+  else:
+    p_choose_i = math_ops.sigmoid(score)
   # Convert from choosing probabilities to attention distribution
   return monotonic_attention(p_choose_i, previous_alignments, mode)
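
Note on the change above: with mode='hard', p_choose_i = cast(score > 0, dtype) is binary, so monotonic_attention places all attention mass on a single position (the first position at or after the previously attended index whose choosing probability is 1), or nowhere at all when every choosing probability is 0. That is exactly what the two new tests assert: every entry of the output is 0 or 1, and each row sums to 0 or 1. Below is a minimal NumPy sketch of that hard-mode recurrence; it is an illustration consistent with the tests above, not the TensorFlow implementation, and the function name hard_monotonic_attention is hypothetical.

import numpy as np

def hard_monotonic_attention(p_choose_i, previous_attention):
  """Illustrative hard-mode recurrence.

  Args:
    p_choose_i: binary choosing probabilities, shape [batch, alignments].
    previous_attention: one-hot previous alignments, same shape.
  """
  # Enforce monotonicity: zero out choices before the previously attended
  # index (the cumulative sum of a one-hot row is 0 before the hot index
  # and 1 from it onward).
  p = p_choose_i * np.cumsum(previous_attention, axis=1)
  # Select the first remaining position with p == 1: the exclusive
  # cumulative product of (1 - p) is 1 up to and including the first
  # chosen index and 0 everywhere after it.
  cumprod = np.cumprod(1.0 - p, axis=1)
  exclusive_cumprod = np.concatenate(
      [np.ones((p.shape[0], 1)), cumprod[:, :-1]], axis=1)
  return p * exclusive_cumprod

# Previous attention at index 0; positions 2 and 4 are "chosen".
p = np.array([[0., 0., 1., 0., 1.]])
prev = np.array([[1., 0., 0., 0., 0.]])
print(hard_monotonic_attention(p, prev))  # [[0. 0. 1. 0. 0.]]
# With an all-zero p_choose_i the result is all zeros, matching the
# "sum of distributions should be 0 or 1" assertion in the tests.

The cumsum/cumprod formulation avoids an explicit loop over alignment positions, which is why the same pattern works as a batched tensor op at graph-construction time.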