author    Colin Raffel <craffel@google.com>    2017-11-07 21:05:37 -0800
committer Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:35 -0800
commit    8e729b337fc062369643b592a96cfbacd6e43712
tree      8a3394066473d5854dfad45d29309fdd6325120f
parent    9a6c0eb137d87f0578821793af64232fc54c53b6
Fix tf.contrib.seq2seq._monotonic_probability_fn to use a hard sigmoid when mode='hard'.
Also adds tests to make sure the attention probabilities are 0 or 1 when mode='hard'.

PiperOrigin-RevId: 174956465
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py  37
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py                 6
2 files changed, 42 insertions(+), 1 deletion(-)
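
For context: the core of the change below swaps the choosing probabilities from a logistic sigmoid to a hard threshold at zero when mode='hard', so every p_choose_i is exactly 0 or 1. A minimal NumPy sketch of that contrast (the helper names soft_choose/hard_choose and the sample scores are illustrative, not part of the commit):

    import numpy as np

    def soft_choose(score):
        # Standard logistic sigmoid: values lie strictly between 0 and 1.
        return 1.0 / (1.0 + np.exp(-score))

    def hard_choose(score):
        # Hard sigmoid as in the patched line: threshold the score at zero,
        # mirroring math_ops.cast(score > 0, score.dtype).
        return (score > 0).astype(score.dtype)

    scores = np.array([-1.5, -0.2, 0.0, 0.3, 2.0], dtype=np.float32)
    print(soft_choose(scores))  # approx. [0.18 0.45 0.50 0.57 0.88] -- soft probabilities
    print(hard_choose(scores))  # [0. 0. 0. 1. 1.] -- exactly 0 or 1

Note that a score of exactly 0 maps to 0 under the strict score > 0 comparison, matching the patched line in attention_wrapper.py.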
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 91493302b1..01a5540121 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope as vs
@@ -589,6 +590,24 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=expected_final_alignment_history,
name='testBahdanauMonotonicNormalized')
+ def testBahdanauMonotonicHard(self):
+ # Run attention mechanism with mode='hard', make sure probabilities are hard
+ b, t, u, d = 10, 20, 30, 40
+ with self.test_session(use_gpu=True) as sess:
+ a = wrapper.BahdanauMonotonicAttention(
+ d,
+ random_ops.random_normal((b, t, u)),
+ mode='hard')
+ # Just feed previous attention as [1, 0, 0, ...]
+ attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+ sess.run(variables.global_variables_initializer())
+ attn_out = attn.eval()
+ # All values should be 0 or 1
+ self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
+ # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
+ self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
+ attn_out.sum(axis=1) == 0)))
+
def testLuongMonotonicNotNormalized(self):
create_attention_mechanism = functools.partial(
wrapper.LuongMonotonicAttention, sigmoid_noise=1.0,
@@ -695,6 +714,24 @@ class AttentionWrapperTest(test.TestCase):
expected_final_alignment_history=expected_final_alignment_history,
name='testMultiAttention')
+ def testLuongMonotonicHard(self):
+ # Run attention mechanism with mode='hard', make sure probabilities are hard
+ b, t, u, d = 10, 20, 30, 40
+ with self.test_session(use_gpu=True) as sess:
+ a = wrapper.LuongMonotonicAttention(
+ d,
+ random_ops.random_normal((b, t, u)),
+ mode='hard')
+ # Just feed previous attention as [1, 0, 0, ...]
+ attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+ sess.run(variables.global_variables_initializer())
+ attn_out = attn.eval()
+ # All values should be 0 or 1
+ self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
+ # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
+ self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
+ attn_out.sum(axis=1) == 0)))
+
def testMultiAttentionNoAttentionLayer(self):
create_attention_mechanisms = (
wrapper.BahdanauAttention, wrapper.LuongAttention)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 839df079ee..87230e3355 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -679,7 +679,11 @@ def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
seed=seed)
score += sigmoid_noise*noise
# Compute "choosing" probabilities from the attention scores
- p_choose_i = math_ops.sigmoid(score)
+ if mode == "hard":
+ # When mode is hard, use a hard sigmoid
+ p_choose_i = math_ops.cast(score > 0, score.dtype)
+ else:
+ p_choose_i = math_ops.sigmoid(score)
# Convert from choosing probabilities to attention distribution
return monotonic_attention(p_choose_i, previous_alignments, mode)
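
The new tests assert that hard-mode attention distributions are binary and sum to 0 or 1 per batch entry. A hedged NumPy sketch of why that holds, assuming the cumulative-product form of the hard monotonic recurrence (the helper hard_monotonic_attention is illustrative, not the actual monotonic_attention implementation):

    import numpy as np

    def hard_monotonic_attention(p_choose_i, previous_attention):
        # Illustrative hard-mode recurrence (an assumption, not the exact TF code):
        # mask out positions before the previously attended index, then attend to
        # the first remaining position whose choosing probability is 1.
        p = p_choose_i * np.cumsum(previous_attention, axis=1)
        # Exclusive cumulative product of (1 - p) along the time axis.
        shifted = np.concatenate([np.ones_like(p[:, :1]), (1.0 - p)[:, :-1]], axis=1)
        return p * np.cumprod(shifted, axis=1)

    # Previous attention one-hot at index 0, as in the tests above.
    prev = np.eye(5, dtype=np.float32)[[0, 0, 0]]         # shape (3, 5)
    p = np.array([[0, 1, 1, 0, 0],                        # first 1 at index 1
                  [0, 0, 0, 0, 1],                        # first 1 at index 4
                  [0, 0, 0, 0, 0]], dtype=np.float32)     # no position chosen
    attn = hard_monotonic_attention(p, prev)
    print(attn)              # each row one-hot at the first chosen index, or all zero
    print(attn.sum(axis=1))  # [1. 1. 0.] -- sums are exactly 0 or 1

With previous attention one-hot at index 0, each row ends up one-hot at the first index whose choosing probability is 1, or all zero when no index is chosen, which is exactly the pair of conditions the added tests check.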