author    Eugene Brevdo <ebrevdo@google.com>  2017-05-22 17:32:50 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-05-22 17:36:11 -0700
commit    827d2e4b9180db67853f60c125e548d83986b96c (patch)
tree      1ccaf8f20bf678ec755330b488eb28946dbe38e6 /tensorflow/contrib/legacy_seq2seq
parent    95719e869c61c78a4b0ac0407e1fb04e60daca35 (diff)
Move many of the "core" RNNCells and rnn functions back to TF core.
Unit test files will move in a follow-up PR. This is the big API change. The
old behavior (using tf.contrib.rnn....) will continue to work for backwards
compatibility.

PiperOrigin-RevId: 156809677
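For readers of this change, a minimal sketch of what the move looks like from
the test code's side, using only imports and calls that appear in the diff
below (the GRUCell size and input shapes are taken from the updated test; this
is illustrative, not part of the commit):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import rnn
    from tensorflow.python.ops import rnn_cell

    # Cells and rnn functions now resolve from TF core rather than from
    # tensorflow.contrib.rnn.python.ops.core_rnn / core_rnn_cell_impl.
    inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
    _, enc_state = rnn.static_rnn(
        rnn_cell.GRUCell(2), inp, dtype=dtypes.float32)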
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py |  95
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py               |  47
2 files changed, 64 insertions(+), 78 deletions(-)
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 2898935a47..4395138e20 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -25,8 +25,7 @@ import random
import numpy as np
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
-from tensorflow.contrib.rnn.python.ops import core_rnn
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
+from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -37,6 +36,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -51,11 +51,10 @@ class Seq2SeqTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- _, enc_state = core_rnn.static_rnn(
- core_rnn_cell_impl.GRUCell(2), inp, dtype=dtypes.float32)
+ _, enc_state = rnn.static_rnn(
+ rnn_cell.GRUCell(2), inp, dtype=dtypes.float32)
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
- cell = core_rnn_cell_impl.OutputProjectionWrapper(
- core_rnn_cell_impl.GRUCell(2), 4)
+ cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq_lib.rnn_decoder(dec_inp, enc_state, cell)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
@@ -71,8 +70,7 @@ class Seq2SeqTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
- cell = core_rnn_cell_impl.OutputProjectionWrapper(
- core_rnn_cell_impl.GRUCell(2), 4)
+ cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq_lib.basic_rnn_seq2seq(inp, dec_inp, cell)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
@@ -88,8 +86,7 @@ class Seq2SeqTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
- cell = core_rnn_cell_impl.OutputProjectionWrapper(
- core_rnn_cell_impl.GRUCell(2), 4)
+ cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
dec, mem = seq2seq_lib.tied_rnn_seq2seq(inp, dec_inp, cell)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
@@ -105,9 +102,9 @@ class Seq2SeqTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
cell = cell_fn()
- _, enc_state = core_rnn.static_rnn(cell, inp, dtype=dtypes.float32)
+ _, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
dec_inp = [
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
@@ -138,7 +135,7 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
cell = cell_fn()
dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
@@ -158,7 +155,7 @@ class Seq2SeqTest(test.TestCase):
# Test with state_is_tuple=False.
with variable_scope.variable_scope("no_tuple"):
- cell_nt = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell_nt = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -242,9 +239,7 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell = functools.partial(
- core_rnn_cell_impl.BasicLSTMCell,
- 2, state_is_tuple=True)
+ cell = functools.partial(rnn_cell.BasicLSTMCell, 2, state_is_tuple=True)
dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2)
sess.run([variables.global_variables_initializer()])
@@ -324,11 +319,10 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
@@ -350,11 +344,10 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
@@ -377,7 +370,7 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
inp = constant_op.constant(0.5, shape=[2, 2, 2])
enc_outputs, enc_state = rnn.dynamic_rnn(
@@ -401,7 +394,7 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
inp = constant_op.constant(0.5, shape=[2, 2, 2])
enc_outputs, enc_state = rnn.dynamic_rnn(
@@ -426,14 +419,13 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda
+ single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda
2, state_is_tuple=True)
- cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
+ cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
cells=[single_cell() for _ in range(2)], state_is_tuple=True)
cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
@@ -459,12 +451,11 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
- cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)])
+ cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
+ cells=[rnn_cell.BasicLSTMCell(2) for _ in range(2)])
cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size])
for e in enc_outputs
@@ -492,10 +483,9 @@ class Seq2SeqTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: rnn_cell.GRUCell(2)
cell = cell_fn()
- enc_outputs, enc_state = core_rnn.static_rnn(
- cell, inp, dtype=dtypes.float32)
+ enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
@@ -534,7 +524,7 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
cell = cell_fn()
dec, mem = seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
@@ -555,8 +545,7 @@ class Seq2SeqTest(test.TestCase):
# Test with state_is_tuple=False.
with variable_scope.variable_scope("no_tuple"):
cell_fn = functools.partial(
- core_rnn_cell_impl.BasicLSTMCell,
- 2, state_is_tuple=False)
+ rnn_cell.BasicLSTMCell, 2, state_is_tuple=False)
cell_nt = cell_fn()
dec, mem = seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
@@ -651,11 +640,10 @@ class Seq2SeqTest(test.TestCase):
]
dec_symbols_dict = {"0": 5, "1": 6}
def EncCellFn():
- return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ return rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
def DecCellsFn():
- return dict(
- (k, core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True))
- for k in dec_symbols_dict)
+ return dict((k, rnn_cell.BasicLSTMCell(2, state_is_tuple=True))
+ for k in dec_symbols_dict)
outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq(
enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(),
2, dec_symbols_dict, embedding_size=2))
@@ -796,8 +784,8 @@ class Seq2SeqTest(test.TestCase):
# """Example sequence-to-sequence model that uses GRU cells."""
# def GRUSeq2Seq(enc_inp, dec_inp):
- # cell = core_rnn_cell_impl.MultiRNNCell(
- # [core_rnn_cell_impl.GRUCell(24) for _ in range(2)])
+ # cell = rnn_cell.MultiRNNCell(
+ # [rnn_cell.GRUCell(24) for _ in range(2)])
# return seq2seq_lib.embedding_attention_seq2seq(
# enc_inp,
# dec_inp,
@@ -862,9 +850,8 @@ class Seq2SeqTest(test.TestCase):
"""Example sequence-to-sequence model that uses GRU cells."""
def GRUSeq2Seq(enc_inp, dec_inp):
- cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(24) for _ in range(2)],
- state_is_tuple=True)
+ cell = rnn_cell.MultiRNNCell(
+ [rnn_cell.GRUCell(24) for _ in range(2)], state_is_tuple=True)
return seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
@@ -1040,7 +1027,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(v_true.eval(), v_false.eval())
def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
return seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -1051,7 +1038,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
return seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -1062,7 +1049,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
return seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -1072,7 +1059,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
return seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -1082,7 +1069,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
return seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
@@ -1093,7 +1080,7 @@ class Seq2SeqTest(test.TestCase):
feed_previous=feed_previous)
def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
return seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index a80b898156..23b4a73b23 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -62,9 +62,7 @@ import copy
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
-from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -72,11 +70,13 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
# TODO(ebrevdo): Remove once _linear is fully deprecated.
-linear = core_rnn_cell_impl._linear # pylint: disable=protected-access
+linear = rnn_cell_impl._linear # pylint: disable=protected-access
def _extract_argmax_and_embed(embedding,
@@ -119,7 +119,7 @@ def rnn_decoder(decoder_inputs,
Args:
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
initial_state: 2D Tensor with shape [batch_size x cell.state_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: rnn_cell.RNNCell defining the cell function and size.
loop_function: If not None, this function will be applied to the i-th output
in order to generate the i+1-st input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol). This can be used for decoding,
@@ -170,7 +170,7 @@ def basic_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 2D Tensors [batch_size x input_size].
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".
@@ -183,7 +183,7 @@ def basic_rnn_seq2seq(encoder_inputs,
"""
with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
enc_cell = copy.deepcopy(cell)
- _, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
+ _, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_state, cell)
@@ -202,7 +202,7 @@ def tied_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 2D Tensors [batch_size x input_size].
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
loop_function: If not None, this function will be applied to i-th output
in order to generate i+1-th input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol), see rnn_decoder for details.
@@ -219,7 +219,7 @@ def tied_rnn_seq2seq(encoder_inputs,
"""
with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
scope = scope or "tied_rnn_seq2seq"
- _, enc_state = core_rnn.static_rnn(
+ _, enc_state = rnn.static_rnn(
cell, encoder_inputs, dtype=dtype, scope=scope)
variable_scope.get_variable_scope().reuse_variables()
return rnn_decoder(
@@ -244,7 +244,7 @@ def embedding_rnn_decoder(decoder_inputs,
Args:
decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
- cell: core_rnn_cell.RNNCell defining the cell function.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function.
num_symbols: Integer, how many symbols come into the embedding.
embedding_size: Integer, the length of the embedding vector for each symbol.
output_projection: None or a pair (W, B) of output projection weights and
@@ -320,7 +320,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols: Integer; number of symbols on the decoder side.
embedding_size: Integer, the length of the embedding vector for each symbol.
@@ -360,8 +360,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
encoder_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- _, encoder_state = core_rnn.static_rnn(
- encoder_cell, encoder_inputs, dtype=dtype)
+ _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
if output_projection is None:
@@ -431,7 +430,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
num_symbols: Integer; number of symbols for both encoder and decoder.
embedding_size: Integer, the length of the embedding vector for each symbol.
num_decoder_symbols: Integer; number of output symbols for decoder. If
@@ -560,7 +559,7 @@ def attention_decoder(decoder_inputs,
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
output_size: Size of the output vectors; if None, we use cell.output_size.
num_heads: Number of attention heads that read from attention_states.
loop_function: If not None, this function will be applied to i-th output
@@ -720,7 +719,7 @@ def embedding_attention_decoder(decoder_inputs,
decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
initial_state: 2D Tensor [batch_size x cell.state_size].
attention_states: 3D Tensor [batch_size x attn_length x attn_size].
- cell: core_rnn_cell.RNNCell defining the cell function.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function.
num_symbols: Integer, how many symbols come into the embedding.
embedding_size: Integer, the length of the embedding vector for each symbol.
num_heads: Number of attention heads that read from attention_states.
@@ -814,7 +813,7 @@ def embedding_attention_seq2seq(encoder_inputs,
Args:
encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols: Integer; number of symbols on the decoder side.
embedding_size: Integer, the length of the embedding vector for each symbol.
@@ -851,7 +850,7 @@ def embedding_attention_seq2seq(encoder_inputs,
encoder_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- encoder_outputs, encoder_state = core_rnn.static_rnn(
+ encoder_outputs, encoder_state = rnn.static_rnn(
encoder_cell, encoder_inputs, dtype=dtype)
# First calculate a concatenation of encoder outputs to put attention on.
@@ -937,9 +936,10 @@ def one2many_rnn_seq2seq(encoder_inputs,
the corresponding decoder_inputs; each decoder_inputs is a list of 1D
Tensors of shape [batch_size]; num_decoders is defined as
len(decoder_inputs_dict).
- enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size.
+ enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and
+ size.
dec_cells_dict: A dictionary mapping encoder name (string) to an
- instance of core_rnn_cell.RNNCell.
+ instance of tf.nn.rnn_cell.RNNCell.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
integer specifying number of symbols for the corresponding decoder;
@@ -971,12 +971,12 @@ def one2many_rnn_seq2seq(encoder_inputs,
outputs_dict = {}
state_dict = {}
- if not isinstance(enc_cell, core_rnn_cell.RNNCell):
+ if not isinstance(enc_cell, rnn_cell_impl.RNNCell):
raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
if set(dec_cells_dict) != set(decoder_inputs_dict):
raise ValueError("keys of dec_cells_dict != keys of decodre_inputs_dict")
for dec_cell in dec_cells_dict.values():
- if not isinstance(dec_cell, core_rnn_cell.RNNCell):
+ if not isinstance(dec_cell, rnn_cell_impl.RNNCell):
raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))
with variable_scope.variable_scope(
@@ -988,8 +988,7 @@ def one2many_rnn_seq2seq(encoder_inputs,
enc_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- _, encoder_state = core_rnn.static_rnn(
- enc_cell, encoder_inputs, dtype=dtype)
+ _, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
# Decoder.
for name, decoder_inputs in decoder_inputs_dict.items():
@@ -1153,7 +1152,7 @@ def model_with_buckets(encoder_inputs,
The seq2seq argument is a function that defines a sequence-to-sequence model,
e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
- x, y, core_rnn_cell.GRUCell(24))
+ x, y, rnn_cell.GRUCell(24))
Args:
encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
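To close, a minimal usage sketch of the updated docstring example in the final
hunk, expanded just enough to build a graph (internal import paths as used
elsewhere in this commit; the batch size, sequence lengths, and cell size of
24 are illustrative assumptions):

    from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import rnn_cell

    # Lists of [batch_size x input_size] tensors, as the seq2seq docstrings
    # describe.
    enc_inp = [constant_op.constant(0.5, shape=[2, 24])] * 2
    dec_inp = [constant_op.constant(0.4, shape=[2, 24])] * 3

    # The docstring example now reads rnn_cell.GRUCell rather than
    # core_rnn_cell.GRUCell.
    seq2seq_fn = lambda x, y: seq2seq_lib.basic_rnn_seq2seq(
        x, y, rnn_cell.GRUCell(24))
    outputs, state = seq2seq_fn(enc_inp, dec_inp)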