diff options
author | Jianwei Xie <xiejw@google.com> | 2016-12-21 21:14:42 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-21 21:25:23 -0800 |
commit | 37fbebdd6c3c8f274896cc36e6feb5b7e2097a59 (patch) | |
tree | ac8ccfc31f69f46e4bce0b34184c22908af24581 /tensorflow/contrib/legacy_seq2seq | |
parent | 0f0e29e7ba06c50fe4a1a7718e63731b96563a8d (diff) |
Move the implementation code of rnn_cells to contrib.
Change: 142730769
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 57 |
1 files changed, 29 insertions, 28 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index b04b6ce16a..0ab854c37c 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -60,7 +60,9 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin -from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn +from tensorflow.contrib.rnn.python.ops import core_rnn +from tensorflow.contrib.rnn.python.ops import core_rnn_cell +from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -68,13 +70,11 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import rnn_cell -from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest # TODO(ebrevdo): Remove once _linear is fully deprecated. -linear = rnn_cell_impl._linear # pylint: disable=protected-access +linear = core_rnn_cell_impl._linear # pylint: disable=protected-access def _extract_argmax_and_embed(embedding, @@ -117,7 +117,7 @@ def rnn_decoder(decoder_inputs, Args: decoder_inputs: A list of 2D Tensors [batch_size x input_size]. initial_state: 2D Tensor with shape [batch_size x cell.state_size]. - cell: rnn_cell.RNNCell defining the cell function and size. + cell: core_rnn_cell.RNNCell defining the cell function and size. loop_function: If not None, this function will be applied to the i-th output in order to generate the i+1-st input, and decoder_inputs will be ignored, except for the first element ("GO" symbol). This can be used for decoding, @@ -168,7 +168,7 @@ def basic_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 2D Tensors [batch_size x input_size]. decoder_inputs: A list of 2D Tensors [batch_size x input_size]. - cell: rnn_cell.RNNCell defining the cell function and size. + cell: core_rnn_cell.RNNCell defining the cell function and size. dtype: The dtype of the initial state of the RNN cell (default: tf.float32). scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". @@ -180,7 +180,7 @@ def basic_rnn_seq2seq(encoder_inputs, It is a 2D Tensor of shape [batch_size x cell.state_size]. """ with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): - _, enc_state = contrib_rnn.static_rnn(cell, encoder_inputs, dtype=dtype) + _, enc_state = core_rnn.static_rnn(cell, encoder_inputs, dtype=dtype) return rnn_decoder(decoder_inputs, enc_state, cell) @@ -199,7 +199,7 @@ def tied_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 2D Tensors [batch_size x input_size]. decoder_inputs: A list of 2D Tensors [batch_size x input_size]. - cell: rnn_cell.RNNCell defining the cell function and size. + cell: core_rnn_cell.RNNCell defining the cell function and size. loop_function: If not None, this function will be applied to i-th output in order to generate i+1-th input, and decoder_inputs will be ignored, except for the first element ("GO" symbol), see rnn_decoder for details. @@ -216,7 +216,7 @@ def tied_rnn_seq2seq(encoder_inputs, """ with variable_scope.variable_scope("combined_tied_rnn_seq2seq"): scope = scope or "tied_rnn_seq2seq" - _, enc_state = contrib_rnn.static_rnn( + _, enc_state = core_rnn.static_rnn( cell, encoder_inputs, dtype=dtype, scope=scope) variable_scope.get_variable_scope().reuse_variables() return rnn_decoder( @@ -241,7 +241,7 @@ def embedding_rnn_decoder(decoder_inputs, Args: decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). initial_state: 2D Tensor [batch_size x cell.state_size]. - cell: rnn_cell.RNNCell defining the cell function. + cell: core_rnn_cell.RNNCell defining the cell function. num_symbols: Integer, how many symbols come into the embedding. embedding_size: Integer, the length of the embedding vector for each symbol. output_projection: None or a pair (W, B) of output projection weights and @@ -317,7 +317,7 @@ def embedding_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: rnn_cell.RNNCell defining the cell function and size. + cell: core_rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols: Integer; number of symbols on the decoder side. embedding_size: Integer, the length of the embedding vector for each symbol. @@ -352,16 +352,16 @@ def embedding_rnn_seq2seq(encoder_inputs, dtype = scope.dtype # Encoder. - encoder_cell = rnn_cell.EmbeddingWrapper( + encoder_cell = core_rnn_cell.EmbeddingWrapper( cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - _, encoder_state = contrib_rnn.static_rnn( + _, encoder_state = core_rnn.static_rnn( encoder_cell, encoder_inputs, dtype=dtype) # Decoder. if output_projection is None: - cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) + cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) if isinstance(feed_previous, bool): return embedding_rnn_decoder( @@ -427,7 +427,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: rnn_cell.RNNCell defining the cell function and size. + cell: core_rnn_cell.RNNCell defining the cell function and size. num_symbols: Integer; number of symbols for both encoder and decoder. embedding_size: Integer, the length of the embedding vector for each symbol. num_decoder_symbols: Integer; number of output symbols for decoder. If @@ -483,7 +483,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, if num_decoder_symbols is not None: output_symbols = num_decoder_symbols if output_projection is None: - cell = rnn_cell.OutputProjectionWrapper(cell, output_symbols) + cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols) if isinstance(feed_previous, bool): loop_function = _extract_argmax_and_embed(embedding, output_projection, @@ -556,7 +556,7 @@ def attention_decoder(decoder_inputs, decoder_inputs: A list of 2D Tensors [batch_size x input_size]. initial_state: 2D Tensor [batch_size x cell.state_size]. attention_states: 3D Tensor [batch_size x attn_length x attn_size]. - cell: rnn_cell.RNNCell defining the cell function and size. + cell: core_rnn_cell.RNNCell defining the cell function and size. output_size: Size of the output vectors; if None, we use cell.output_size. num_heads: Number of attention heads that read from attention_states. loop_function: If not None, this function will be applied to i-th output @@ -716,7 +716,7 @@ def embedding_attention_decoder(decoder_inputs, decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). initial_state: 2D Tensor [batch_size x cell.state_size]. attention_states: 3D Tensor [batch_size x attn_length x attn_size]. - cell: rnn_cell.RNNCell defining the cell function. + cell: core_rnn_cell.RNNCell defining the cell function. num_symbols: Integer, how many symbols come into the embedding. embedding_size: Integer, the length of the embedding vector for each symbol. num_heads: Number of attention heads that read from attention_states. @@ -810,7 +810,7 @@ def embedding_attention_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: rnn_cell.RNNCell defining the cell function and size. + cell: core_rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols: Integer; number of symbols on the decoder side. embedding_size: Integer, the length of the embedding vector for each symbol. @@ -842,11 +842,11 @@ def embedding_attention_seq2seq(encoder_inputs, scope or "embedding_attention_seq2seq", dtype=dtype) as scope: dtype = scope.dtype # Encoder. - encoder_cell = rnn_cell.EmbeddingWrapper( + encoder_cell = core_rnn_cell.EmbeddingWrapper( cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - encoder_outputs, encoder_state = contrib_rnn.static_rnn( + encoder_outputs, encoder_state = core_rnn.static_rnn( encoder_cell, encoder_inputs, dtype=dtype) # First calculate a concatenation of encoder outputs to put attention on. @@ -858,7 +858,7 @@ def embedding_attention_seq2seq(encoder_inputs, # Decoder. output_size = None if output_projection is None: - cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) + cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) output_size = num_decoder_symbols if isinstance(feed_previous, bool): @@ -931,7 +931,7 @@ def one2many_rnn_seq2seq(encoder_inputs, the corresponding decoder_inputs; each decoder_inputs is a list of 1D Tensors of shape [batch_size]; num_decoders is defined as len(decoder_inputs_dict). - cell: rnn_cell.RNNCell defining the cell function and size. + cell: core_rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an integer specifying number of symbols for the corresponding decoder; @@ -964,11 +964,11 @@ def one2many_rnn_seq2seq(encoder_inputs, dtype = scope.dtype # Encoder. - encoder_cell = rnn_cell.EmbeddingWrapper( + encoder_cell = core_rnn_cell.EmbeddingWrapper( cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - _, encoder_state = contrib_rnn.static_rnn( + _, encoder_state = core_rnn.static_rnn( encoder_cell, encoder_inputs, dtype=dtype) # Decoder. @@ -977,8 +977,8 @@ def one2many_rnn_seq2seq(encoder_inputs, with variable_scope.variable_scope("one2many_decoder_" + str( name)) as scope: - decoder_cell = rnn_cell.OutputProjectionWrapper(cell, - num_decoder_symbols) + decoder_cell = core_rnn_cell.OutputProjectionWrapper( + cell, num_decoder_symbols) if isinstance(feed_previous, bool): outputs, state = embedding_rnn_decoder( decoder_inputs, @@ -1127,7 +1127,8 @@ def model_with_buckets(encoder_inputs, """Create a sequence-to-sequence model with support for bucketing. The seq2seq argument is a function that defines a sequence-to-sequence model, - e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24)) + e.g., seq2seq = lambda x, y: basic_rnn_seq2seq( + x, y, core_rnn_cell.GRUCell(24)) Args: encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input. |