aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/legacy_seq2seq
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2016-12-21 21:14:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-21 21:25:23 -0800
commit37fbebdd6c3c8f274896cc36e6feb5b7e2097a59 (patch)
treeac8ccfc31f69f46e4bce0b34184c22908af24581 /tensorflow/contrib/legacy_seq2seq
parent0f0e29e7ba06c50fe4a1a7718e63731b96563a8d (diff)
Move the implementation code of rnn_cells to contrib.
Change: 142730769
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py57
1 files changed, 29 insertions, 28 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index b04b6ce16a..0ab854c37c 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -60,7 +60,9 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
-from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn
+from tensorflow.contrib.rnn.python.ops import core_rnn
+from tensorflow.contrib.rnn.python.ops import core_rnn_cell
+from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -68,13 +70,11 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import rnn_cell
-from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
# TODO(ebrevdo): Remove once _linear is fully deprecated.
-linear = rnn_cell_impl._linear # pylint: disable=protected-access
+linear = core_rnn_cell_impl._linear # pylint: disable=protected-access
def _extract_argmax_and_embed(embedding,
@@ -117,7 +117,7 @@ def rnn_decoder(decoder_inputs,
Args:
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
initial_state: 2D Tensor with shape [batch_size x cell.state_size].
- cell: rnn_cell.RNNCell defining the cell function and size.
+ cell: core_rnn_cell.RNNCell defining the cell function and size.
loop_function: If not None, this function will be applied to the i-th output
in order to generate the i+1-st input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol). This can be used for decoding,
@@ -168,7 +168,7 @@ def basic_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 2D Tensors [batch_size x input_size].
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
- cell: rnn_cell.RNNCell defining the cell function and size.
+ cell: core_rnn_cell.RNNCell defining the cell function and size.
dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".
@@ -180,7 +180,7 @@ def basic_rnn_seq2seq(encoder_inputs,
It is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
- _, enc_state = contrib_rnn.static_rnn(cell, encoder_inputs, dtype=dtype)
+ _, enc_state = core_rnn.static_rnn(cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_state, cell)
@@ -199,7 +199,7 @@ def tied_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 2D Tensors [batch_size x input_size].
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
- cell: rnn_cell.RNNCell defining the cell function and size.
+ cell: core_rnn_cell.RNNCell defining the cell function and size.
loop_function: If not None, this function will be applied to i-th output
in order to generate i+1-th input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol), see rnn_decoder for details.
@@ -216,7 +216,7 @@ def tied_rnn_seq2seq(encoder_inputs,
"""
with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
scope = scope or "tied_rnn_seq2seq"
- _, enc_state = contrib_rnn.static_rnn(
+ _, enc_state = core_rnn.static_rnn(
cell, encoder_inputs, dtype=dtype, scope=scope)
variable_scope.get_variable_scope().reuse_variables()
return rnn_decoder(
@@ -241,7 +241,7 @@ def embedding_rnn_decoder(decoder_inputs,
Args:
decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
- cell: rnn_cell.RNNCell defining the cell function.
+ cell: core_rnn_cell.RNNCell defining the cell function.
num_symbols: Integer, how many symbols come into the embedding.
embedding_size: Integer, the length of the embedding vector for each symbol.
output_projection: None or a pair (W, B) of output projection weights and
@@ -317,7 +317,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: rnn_cell.RNNCell defining the cell function and size.
+ cell: core_rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols: Integer; number of symbols on the decoder side.
embedding_size: Integer, the length of the embedding vector for each symbol.
@@ -352,16 +352,16 @@ def embedding_rnn_seq2seq(encoder_inputs,
dtype = scope.dtype
# Encoder.
- encoder_cell = rnn_cell.EmbeddingWrapper(
+ encoder_cell = core_rnn_cell.EmbeddingWrapper(
cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- _, encoder_state = contrib_rnn.static_rnn(
+ _, encoder_state = core_rnn.static_rnn(
encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
if output_projection is None:
- cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
+ cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
if isinstance(feed_previous, bool):
return embedding_rnn_decoder(
@@ -427,7 +427,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: rnn_cell.RNNCell defining the cell function and size.
+ cell: core_rnn_cell.RNNCell defining the cell function and size.
num_symbols: Integer; number of symbols for both encoder and decoder.
embedding_size: Integer, the length of the embedding vector for each symbol.
num_decoder_symbols: Integer; number of output symbols for decoder. If
@@ -483,7 +483,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs,
if num_decoder_symbols is not None:
output_symbols = num_decoder_symbols
if output_projection is None:
- cell = rnn_cell.OutputProjectionWrapper(cell, output_symbols)
+ cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols)
if isinstance(feed_previous, bool):
loop_function = _extract_argmax_and_embed(embedding, output_projection,
@@ -556,7 +556,7 @@ def attention_decoder(decoder_inputs,
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
- cell: rnn_cell.RNNCell defining the cell function and size.
+ cell: core_rnn_cell.RNNCell defining the cell function and size.
output_size: Size of the output vectors; if None, we use cell.output_size.
num_heads: Number of attention heads that read from attention_states.
loop_function: If not None, this function will be applied to i-th output
@@ -716,7 +716,7 @@ def embedding_attention_decoder(decoder_inputs,
decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
- cell: rnn_cell.RNNCell defining the cell function.
+ cell: core_rnn_cell.RNNCell defining the cell function.
num_symbols: Integer, how many symbols come into the embedding.
embedding_size: Integer, the length of the embedding vector for each symbol.
num_heads: Number of attention heads that read from attention_states.
@@ -810,7 +810,7 @@ def embedding_attention_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: rnn_cell.RNNCell defining the cell function and size.
+ cell: core_rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols: Integer; number of symbols on the decoder side.
embedding_size: Integer, the length of the embedding vector for each symbol.
@@ -842,11 +842,11 @@ def embedding_attention_seq2seq(encoder_inputs,
scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
dtype = scope.dtype
# Encoder.
- encoder_cell = rnn_cell.EmbeddingWrapper(
+ encoder_cell = core_rnn_cell.EmbeddingWrapper(
cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- encoder_outputs, encoder_state = contrib_rnn.static_rnn(
+ encoder_outputs, encoder_state = core_rnn.static_rnn(
encoder_cell, encoder_inputs, dtype=dtype)
# First calculate a concatenation of encoder outputs to put attention on.
@@ -858,7 +858,7 @@ def embedding_attention_seq2seq(encoder_inputs,
# Decoder.
output_size = None
if output_projection is None:
- cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
+ cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
output_size = num_decoder_symbols
if isinstance(feed_previous, bool):
@@ -931,7 +931,7 @@ def one2many_rnn_seq2seq(encoder_inputs,
the corresponding decoder_inputs; each decoder_inputs is a list of 1D
Tensors of shape [batch_size]; num_decoders is defined as
len(decoder_inputs_dict).
- cell: rnn_cell.RNNCell defining the cell function and size.
+ cell: core_rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
integer specifying number of symbols for the corresponding decoder;
@@ -964,11 +964,11 @@ def one2many_rnn_seq2seq(encoder_inputs,
dtype = scope.dtype
# Encoder.
- encoder_cell = rnn_cell.EmbeddingWrapper(
+ encoder_cell = core_rnn_cell.EmbeddingWrapper(
cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- _, encoder_state = contrib_rnn.static_rnn(
+ _, encoder_state = core_rnn.static_rnn(
encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
@@ -977,8 +977,8 @@ def one2many_rnn_seq2seq(encoder_inputs,
with variable_scope.variable_scope("one2many_decoder_" + str(
name)) as scope:
- decoder_cell = rnn_cell.OutputProjectionWrapper(cell,
- num_decoder_symbols)
+ decoder_cell = core_rnn_cell.OutputProjectionWrapper(
+ cell, num_decoder_symbols)
if isinstance(feed_previous, bool):
outputs, state = embedding_rnn_decoder(
decoder_inputs,
@@ -1127,7 +1127,8 @@ def model_with_buckets(encoder_inputs,
"""Create a sequence-to-sequence model with support for bucketing.
The seq2seq argument is a function that defines a sequence-to-sequence model,
- e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24))
+ e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
+ x, y, core_rnn_cell.GRUCell(24))
Args:
encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.