diff options
author | Jianwei Xie <xiejw@google.com> | 2016-12-15 09:00:04 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-15 09:06:41 -0800 |
commit | 5a21e678bd8b0fe294717a0362d41c6dfc467cef (patch) | |
tree | 8e386fe2a0abcd85b5e334551f5ffbd1bebe61ad /tensorflow/contrib/legacy_seq2seq | |
parent | fd7ff167e6f02fe0966fa70ef52a99d16e0490ec (diff) |
Move the implementation code of static_rnn, static_bidirectional_rnn and static_state_saving_rnn from core to contrib.
Change: 142148394
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index d2018bb219..b04b6ce16a 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -60,6 +60,7 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin +from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -67,7 +68,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope @@ -180,7 +180,7 @@ def basic_rnn_seq2seq(encoder_inputs, It is a 2D Tensor of shape [batch_size x cell.state_size]. """ with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): - _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype) + _, enc_state = contrib_rnn.static_rnn(cell, encoder_inputs, dtype=dtype) return rnn_decoder(decoder_inputs, enc_state, cell) @@ -216,7 +216,8 @@ def tied_rnn_seq2seq(encoder_inputs, """ with variable_scope.variable_scope("combined_tied_rnn_seq2seq"): scope = scope or "tied_rnn_seq2seq" - _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype, scope=scope) + _, enc_state = contrib_rnn.static_rnn( + cell, encoder_inputs, dtype=dtype, scope=scope) variable_scope.get_variable_scope().reuse_variables() return rnn_decoder( decoder_inputs, @@ -355,7 +356,8 @@ def embedding_rnn_seq2seq(encoder_inputs, cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) + _, encoder_state = contrib_rnn.static_rnn( + encoder_cell, encoder_inputs, dtype=dtype) # Decoder. if output_projection is None: @@ -844,9 +846,8 @@ def embedding_attention_seq2seq(encoder_inputs, cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - encoder_outputs, encoder_state = rnn.rnn(encoder_cell, - encoder_inputs, - dtype=dtype) + encoder_outputs, encoder_state = contrib_rnn.static_rnn( + encoder_cell, encoder_inputs, dtype=dtype) # First calculate a concatenation of encoder outputs to put attention on. top_states = [ @@ -967,7 +968,8 @@ def one2many_rnn_seq2seq(encoder_inputs, cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) + _, encoder_state = contrib_rnn.static_rnn( + encoder_cell, encoder_inputs, dtype=dtype) # Decoder. for name, decoder_inputs in decoder_inputs_dict.items(): |