diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-03 12:49:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-03 12:53:46 -0700 |
commit | 3f579020bab8f00e4621e9c7c740cbf13136a809 (patch) | |
tree | ea65a19dcbeebfa9889fd7a323ac7d3b08acef6a /tensorflow/contrib/legacy_seq2seq | |
parent | 9be96491599cd8890092f7010d4afd22862b26dd (diff) |
Convert cells to OO-based to reduce call() overhead
PiperOrigin-RevId: 170898081
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index d4de638338..8313aa355d 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -76,7 +76,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest # TODO(ebrevdo): Remove once _linear is fully deprecated. -linear = rnn_cell_impl._linear # pylint: disable=protected-access +Linear = rnn_cell_impl._Linear # pylint: disable=protected-access,invalid-name def _extract_argmax_and_embed(embedding, @@ -645,7 +645,7 @@ def attention_decoder(decoder_inputs, query = array_ops.concat(query_list, 1) for a in xrange(num_heads): with variable_scope.variable_scope("Attention_%d" % a): - y = linear(query, attention_vec_size, True) + y = Linear(query, attention_vec_size, True)(query) y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) # Attention mask is a softmax of v^T * tanh(...). s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), @@ -679,7 +679,9 @@ def attention_decoder(decoder_inputs, input_size = inp.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from input: %s" % inp.name) - x = linear([inp] + attns, input_size, True) + + inputs = [inp] + attns + x = Linear(inputs, input_size, True)(inputs) # Run the RNN. cell_output, state = cell(x, state) # Run the attention mechanism. @@ -691,7 +693,8 @@ def attention_decoder(decoder_inputs, attns = attention(state) with variable_scope.variable_scope("AttnOutputProjection"): - output = linear([cell_output] + attns, output_size, True) + inputs = [cell_output] + attns + output = Linear(inputs, output_size, True)(inputs) if loop_function is not None: prev = output outputs.append(output) |