author     A. Unique TensorFlower <gardener@tensorflow.org>    2017-10-03 12:49:33 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2017-10-03 12:53:46 -0700
commit     3f579020bab8f00e4621e9c7c740cbf13136a809 (patch)
tree       ea65a19dcbeebfa9889fd7a323ac7d3b08acef6a /tensorflow/contrib/legacy_seq2seq
parent     9be96491599cd8890092f7010d4afd22862b26dd (diff)
Convert cells to OO-based to reduce call() overhead
PiperOrigin-RevId: 170898081
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--    tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py    11
1 file changed, 7 insertions, 4 deletions
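
The patch below swaps the module-level linear helper for the class-based rnn_cell_impl._Linear (imported as Linear): the functional helper repeated its argument checks, shape inference, and variable lookup on every invocation, whereas a class can do that setup once at construction and keep __call__ cheap. The following sketch only illustrates that functional-vs-OO pattern with a NumPy stand-in; it is not the real _Linear implementation, and the names and shapes are invented for the example.

import numpy as np

def linear(args, output_size):
    """Functional style: shape inference and weight lookup repeat on every call."""
    x = np.concatenate(args, axis=1)
    w = np.zeros((x.shape[1], output_size))        # stands in for variable creation/lookup
    return x @ w

class Linear(object):
    """OO style: the expensive setup runs once, in the constructor."""

    def __init__(self, args, output_size, build_bias=True):
        total = sum(a.shape[1] for a in args)       # shapes inferred once
        self._w = np.zeros((total, output_size))    # weights resolved once
        self._b = np.zeros(output_size) if build_bias else None

    def __call__(self, args):
        out = np.concatenate(args, axis=1) @ self._w   # only the matmul per call
        return out if self._b is None else out + self._b

# Construct once, then apply, mirroring the construct-then-call sites in the patch.
inputs = [np.ones((2, 3)), np.ones((2, 5))]
proj = Linear(inputs, output_size=4)
out = proj(inputs)
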
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index d4de638338..8313aa355d 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -76,7 +76,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import nest
 
 # TODO(ebrevdo): Remove once _linear is fully deprecated.
-linear = rnn_cell_impl._linear  # pylint: disable=protected-access
+Linear = rnn_cell_impl._Linear  # pylint: disable=protected-access,invalid-name
 
 
 def _extract_argmax_and_embed(embedding,
@@ -645,7 +645,7 @@ def attention_decoder(decoder_inputs,
         query = array_ops.concat(query_list, 1)
       for a in xrange(num_heads):
         with variable_scope.variable_scope("Attention_%d" % a):
-          y = linear(query, attention_vec_size, True)
+          y = Linear(query, attention_vec_size, True)(query)
           y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
           # Attention mask is a softmax of v^T * tanh(...).
           s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
@@ -679,7 +679,9 @@ def attention_decoder(decoder_inputs,
       input_size = inp.get_shape().with_rank(2)[1]
       if input_size.value is None:
         raise ValueError("Could not infer input size from input: %s" % inp.name)
-      x = linear([inp] + attns, input_size, True)
+
+      inputs = [inp] + attns
+      x = Linear(inputs, input_size, True)(inputs)
       # Run the RNN.
       cell_output, state = cell(x, state)
       # Run the attention mechanism.
@@ -691,7 +693,8 @@
         attns = attention(state)
 
       with variable_scope.variable_scope("AttnOutputProjection"):
-        output = linear([cell_output] + attns, output_size, True)
+        inputs = [cell_output] + attns
+        output = Linear(inputs, output_size, True)(inputs)
       if loop_function is not None:
         prev = output
       outputs.append(output)
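
For the cell-side motivation named in the commit message ("Convert cells to OO-based to reduce call() overhead"), the idea is that a cell holds on to its projection object so repeated timestep calls skip the setup work. A minimal sketch, assuming a toy NumPy cell rather than TensorFlow's actual RNNCell API:

import numpy as np

class ToyOOCell(object):
    """Illustrative only: caches its kernel so repeated calls skip setup."""

    def __init__(self, num_units):
        self._num_units = num_units
        self._kernel = None                          # built lazily, exactly once

    def __call__(self, inputs, state):
        concat = np.concatenate([inputs, state], axis=1)
        if self._kernel is None:                     # one-time "build" step
            self._kernel = np.zeros((concat.shape[1], self._num_units))
        new_state = np.tanh(concat @ self._kernel)   # cheap per-step work
        return new_state, new_state

cell = ToyOOCell(num_units=8)
state = np.zeros((2, 8))
for step_input in [np.ones((2, 3))] * 5:             # later steps reuse the cached kernel
    _, state = cell(step_input, state)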