path: root/tensorflow/contrib/seq2seq
author	Yong Tang <yong.tang.github@outlook.com>	2018-04-01 02:01:47 +0000
committer	Yong Tang <yong.tang.github@outlook.com>	2018-04-16 20:04:11 +0000
commit	35b8a8cfebe910687f3cc038c00a6e33ba09637a (patch)
tree	15a475aa44d1d4415df19b2c5c2d2eee6c4feb65 /tensorflow/contrib/seq2seq
parent	f7da45918a6e0d7981eeb2ddf03439ace17f3af7 (diff)
Fix the issue with Bahdanau attention when normalized=True and dtype = float16/32
While revisiting 18016, I noticed that Bahdanau attention has a similar dtype mismatch issue when normalized=True. The issue comes from:
```
g = variable_scope.get_variable(
    "attention_g", dtype=dtype,
    initializer=math.sqrt((1. / num_units)))
```
where the initializer value does not work well with different dtypes. This fix changes the initializer to `init_ops.constant_initializer` to address the issue, and adds additional test cases for it.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--	tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py	2
1 file changed, 1 insertion, 1 deletion
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 9ba541ce23..867e49b565 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -472,7 +472,7 @@ def _bahdanau_score(processed_query, keys, normalize):
     # Scalar used in weight normalization
     g = variable_scope.get_variable(
         "attention_g", dtype=dtype,
-        initializer=math.sqrt((1. / num_units)))
+        initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))), shape=())
     # Bias added prior to the nonlinearity
     b = variable_scope.get_variable(
         "attention_b", [num_units], dtype=dtype,