path: root/tensorflow/contrib/legacy_seq2seq
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-12 14:25:47 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-12 14:29:13 -0700
commit 7a3fa74736e4359c208dbb66db38c186d6cf6813 (patch)
tree 2d84b0d89ade2f00792b18a31ed147e51b57cf6a /tensorflow/contrib/legacy_seq2seq
parent c61c7f1ea6a5e6aa0af19eb21d03b351031d944c (diff)
Fix support for seq2seq with mixed precision
When the type of the input tensor `x` is not the same as the type of the hidden states, a cast is required. This mixed-precision case occurs when using the seq2seq layer with a data type of float16 or bfloat16.

PiperOrigin-RevId: 204364209
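For illustration, here is a minimal sketch of the cast pattern this change applies (hypothetical names, not the seq2seq code itself): attention scores are computed in the model dtype, the softmax is taken in float32, and the result is cast back so it matches the hidden states.

    import tensorflow as tf

    def attention_weights(scores, dtype=tf.float16):
      # Compute the softmax in float32 to avoid float16/bfloat16 precision issues.
      weights = tf.nn.softmax(tf.cast(scores, tf.float32))
      # Cast back so the weighted sum over the hidden states sees a single dtype.
      return tf.cast(weights, dtype)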
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py  14
1 file changed, 10 insertions, 4 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index 5e7b422e3c..e742447208 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -625,11 +625,13 @@ def attention_decoder(decoder_inputs,
v = []
attention_vec_size = attn_size # Size of query vectors for attention.
for a in xrange(num_heads):
- k = variable_scope.get_variable("AttnW_%d" % a,
- [1, 1, attn_size, attention_vec_size])
+ k = variable_scope.get_variable(
+ "AttnW_%d" % a, [1, 1, attn_size, attention_vec_size],
+ dtype=dtype)
hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
v.append(
- variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))
+ variable_scope.get_variable(
+ "AttnV_%d" % a, [attention_vec_size], dtype=dtype))
state = initial_state
@@ -647,11 +649,13 @@ def attention_decoder(decoder_inputs,
with variable_scope.variable_scope("Attention_%d" % a):
y = Linear(query, attention_vec_size, True)(query)
y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
+ y = math_ops.cast(y, dtype)
# Attention mask is a softmax of v^T * tanh(...).
s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
[2, 3])
- a = nn_ops.softmax(s)
+ a = nn_ops.softmax(math_ops.cast(s, dtype=dtypes.float32))
# Now calculate the attention-weighted vector d.
+ a = math_ops.cast(a, dtype)
d = math_ops.reduce_sum(
array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
ds.append(array_ops.reshape(d, [-1, attn_size]))
@@ -681,6 +685,7 @@ def attention_decoder(decoder_inputs,
raise ValueError("Could not infer input size from input: %s" % inp.name)
inputs = [inp] + attns
+ inputs = [math_ops.cast(e, dtype) for e in inputs]
x = Linear(inputs, input_size, True)(inputs)
# Run the RNN.
cell_output, state = cell(x, state)
@@ -693,6 +698,7 @@ def attention_decoder(decoder_inputs,
attns = attention(state)
with variable_scope.variable_scope("AttnOutputProjection"):
+ cell_output = math_ops.cast(cell_output, dtype)
inputs = [cell_output] + attns
output = Linear(inputs, output_size, True)(inputs)
if loop_function is not None:
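As a rough usage sketch of what the fix enables (assumed TF 1.x style; the shapes, the GRUCell choice, and the variable names are illustrative and not part of this commit), the attention decoder can now be driven end to end in a reduced-precision dtype:

    import tensorflow as tf
    from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq

    batch, attn_len, attn_size, input_size = 4, 7, 16, 16
    dtype = tf.bfloat16  # or tf.float16

    cell = tf.nn.rnn_cell.GRUCell(attn_size)
    decoder_inputs = [
        tf.zeros([batch, input_size], dtype=dtype) for _ in range(3)]
    attention_states = tf.zeros([batch, attn_len, attn_size], dtype=dtype)
    initial_state = cell.zero_state(batch, dtype)

    # With the casts above, the attention variables, softmax, and RNN inputs
    # all agree on `dtype`, so this no longer raises a dtype mismatch.
    outputs, state = seq2seq.attention_decoder(
        decoder_inputs, initial_state, attention_states, cell, dtype=dtype)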