diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-05-22 17:32:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-22 17:36:11 -0700 |
commit | 827d2e4b9180db67853f60c125e548d83986b96c (patch) | |
tree | 1ccaf8f20bf678ec755330b488eb28946dbe38e6 /tensorflow/contrib/legacy_seq2seq | |
parent | 95719e869c61c78a4b0ac0407e1fb04e60daca35 (diff) |
Move many of the "core" RNNCells and rnn functions back to TF core.
Unit test files will move in a followup PR. This is the big API change.
The old behavior (using tf.contrib.rnn....) will continue to work for
backwards compatibility.
PiperOrigin-RevId: 156809677
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py | 95 | ||||
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 47 |
2 files changed, 64 insertions, 78 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py index 2898935a47..4395138e20 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py @@ -25,8 +25,7 @@ import random import numpy as np from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib -from tensorflow.contrib.rnn.python.ops import core_rnn -from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl +from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -37,6 +36,7 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -51,11 +51,10 @@ class Seq2SeqTest(test.TestCase): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - _, enc_state = core_rnn.static_rnn( - core_rnn_cell_impl.GRUCell(2), inp, dtype=dtypes.float32) + _, enc_state = rnn.static_rnn( + rnn_cell.GRUCell(2), inp, dtype=dtypes.float32) dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 - cell = core_rnn_cell_impl.OutputProjectionWrapper( - core_rnn_cell_impl.GRUCell(2), 4) + cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) dec, mem = seq2seq_lib.rnn_decoder(dec_inp, enc_state, cell) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) @@ -71,8 +70,7 @@ class Seq2SeqTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 - cell = core_rnn_cell_impl.OutputProjectionWrapper( - core_rnn_cell_impl.GRUCell(2), 4) + cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) dec, mem = seq2seq_lib.basic_rnn_seq2seq(inp, dec_inp, cell) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) @@ -88,8 +86,7 @@ class Seq2SeqTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 - cell = core_rnn_cell_impl.OutputProjectionWrapper( - core_rnn_cell_impl.GRUCell(2), 4) + cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) dec, mem = seq2seq_lib.tied_rnn_seq2seq(inp, dec_inp, cell) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) @@ -105,9 +102,9 @@ class Seq2SeqTest(test.TestCase): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2) + cell_fn = lambda: rnn_cell.BasicLSTMCell(2) cell = cell_fn() - _, enc_state = core_rnn.static_rnn(cell, inp, dtype=dtypes.float32) + _, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) dec_inp = [ constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) @@ -138,7 +135,7 @@ class Seq2SeqTest(test.TestCase): constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) ] - cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2) + cell_fn = lambda: rnn_cell.BasicLSTMCell(2) cell = cell_fn() dec, mem = seq2seq_lib.embedding_rnn_seq2seq( enc_inp, @@ -158,7 +155,7 @@ class Seq2SeqTest(test.TestCase): # Test with state_is_tuple=False. with variable_scope.variable_scope("no_tuple"): - cell_nt = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + cell_nt = rnn_cell.BasicLSTMCell(2, state_is_tuple=False) dec, mem = seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp, @@ -242,9 +239,7 @@ class Seq2SeqTest(test.TestCase): constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) ] - cell = functools.partial( - core_rnn_cell_impl.BasicLSTMCell, - 2, state_is_tuple=True) + cell = functools.partial(rnn_cell.BasicLSTMCell, 2, state_is_tuple=True) dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq( enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2) sess.run([variables.global_variables_initializer()]) @@ -324,11 +319,10 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: rnn_cell.GRUCell(2) cell = cell_fn() inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - enc_outputs, enc_state = core_rnn.static_rnn( - cell, inp, dtype=dtypes.float32) + enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) attn_states = array_ops.concat([ array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs ], 1) @@ -350,11 +344,10 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: rnn_cell.GRUCell(2) cell = cell_fn() inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - enc_outputs, enc_state = core_rnn.static_rnn( - cell, inp, dtype=dtypes.float32) + enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) attn_states = array_ops.concat([ array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs ], 1) @@ -377,7 +370,7 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: rnn_cell.GRUCell(2) cell = cell_fn() inp = constant_op.constant(0.5, shape=[2, 2, 2]) enc_outputs, enc_state = rnn.dynamic_rnn( @@ -401,7 +394,7 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: rnn_cell.GRUCell(2) cell = cell_fn() inp = constant_op.constant(0.5, shape=[2, 2, 2]) enc_outputs, enc_state = rnn.dynamic_rnn( @@ -426,14 +419,13 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda + single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda 2, state_is_tuple=True) - cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda + cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda cells=[single_cell() for _ in range(2)], state_is_tuple=True) cell = cell_fn() inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - enc_outputs, enc_state = core_rnn.static_rnn( - cell, inp, dtype=dtypes.float32) + enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) attn_states = array_ops.concat([ array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs ], 1) @@ -459,12 +451,11 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda - cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)]) + cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda + cells=[rnn_cell.BasicLSTMCell(2) for _ in range(2)]) cell = cell_fn() inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - enc_outputs, enc_state = core_rnn.static_rnn( - cell, inp, dtype=dtypes.float32) + enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) attn_states = array_ops.concat([ array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs @@ -492,10 +483,9 @@ class Seq2SeqTest(test.TestCase): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: rnn_cell.GRUCell(2) cell = cell_fn() - enc_outputs, enc_state = core_rnn.static_rnn( - cell, inp, dtype=dtypes.float32) + enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) attn_states = array_ops.concat([ array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs ], 1) @@ -534,7 +524,7 @@ class Seq2SeqTest(test.TestCase): constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) ] - cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2) + cell_fn = lambda: rnn_cell.BasicLSTMCell(2) cell = cell_fn() dec, mem = seq2seq_lib.embedding_attention_seq2seq( enc_inp, @@ -555,8 +545,7 @@ class Seq2SeqTest(test.TestCase): # Test with state_is_tuple=False. with variable_scope.variable_scope("no_tuple"): cell_fn = functools.partial( - core_rnn_cell_impl.BasicLSTMCell, - 2, state_is_tuple=False) + rnn_cell.BasicLSTMCell, 2, state_is_tuple=False) cell_nt = cell_fn() dec, mem = seq2seq_lib.embedding_attention_seq2seq( enc_inp, @@ -651,11 +640,10 @@ class Seq2SeqTest(test.TestCase): ] dec_symbols_dict = {"0": 5, "1": 6} def EncCellFn(): - return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + return rnn_cell.BasicLSTMCell(2, state_is_tuple=True) def DecCellsFn(): - return dict( - (k, core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)) - for k in dec_symbols_dict) + return dict((k, rnn_cell.BasicLSTMCell(2, state_is_tuple=True)) + for k in dec_symbols_dict) outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq( enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(), 2, dec_symbols_dict, embedding_size=2)) @@ -796,8 +784,8 @@ class Seq2SeqTest(test.TestCase): # """Example sequence-to-sequence model that uses GRU cells.""" # def GRUSeq2Seq(enc_inp, dec_inp): - # cell = core_rnn_cell_impl.MultiRNNCell( - # [core_rnn_cell_impl.GRUCell(24) for _ in range(2)]) + # cell = rnn_cell.MultiRNNCell( + # [rnn_cell.GRUCell(24) for _ in range(2)]) # return seq2seq_lib.embedding_attention_seq2seq( # enc_inp, # dec_inp, @@ -862,9 +850,8 @@ class Seq2SeqTest(test.TestCase): """Example sequence-to-sequence model that uses GRU cells.""" def GRUSeq2Seq(enc_inp, dec_inp): - cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(24) for _ in range(2)], - state_is_tuple=True) + cell = rnn_cell.MultiRNNCell( + [rnn_cell.GRUCell(24) for _ in range(2)], state_is_tuple=True) return seq2seq_lib.embedding_attention_seq2seq( enc_inp, dec_inp, @@ -1040,7 +1027,7 @@ class Seq2SeqTest(test.TestCase): self.assertAllClose(v_true.eval(), v_false.eval()) def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True) return seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp, @@ -1051,7 +1038,7 @@ class Seq2SeqTest(test.TestCase): feed_previous=feed_previous) def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False) return seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp, @@ -1062,7 +1049,7 @@ class Seq2SeqTest(test.TestCase): feed_previous=feed_previous) def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True) return seq2seq_lib.embedding_tied_rnn_seq2seq( enc_inp, dec_inp, @@ -1072,7 +1059,7 @@ class Seq2SeqTest(test.TestCase): feed_previous=feed_previous) def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False) return seq2seq_lib.embedding_tied_rnn_seq2seq( enc_inp, dec_inp, @@ -1082,7 +1069,7 @@ class Seq2SeqTest(test.TestCase): feed_previous=feed_previous) def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True) return seq2seq_lib.embedding_attention_seq2seq( enc_inp, dec_inp, @@ -1093,7 +1080,7 @@ class Seq2SeqTest(test.TestCase): feed_previous=feed_previous) def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False) return seq2seq_lib.embedding_attention_seq2seq( enc_inp, dec_inp, diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index a80b898156..23b4a73b23 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -62,9 +62,7 @@ import copy from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin -from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.contrib.rnn.python.ops import core_rnn_cell -from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -72,11 +70,13 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest # TODO(ebrevdo): Remove once _linear is fully deprecated. -linear = core_rnn_cell_impl._linear # pylint: disable=protected-access +linear = rnn_cell_impl._linear # pylint: disable=protected-access def _extract_argmax_and_embed(embedding, @@ -119,7 +119,7 @@ def rnn_decoder(decoder_inputs, Args: decoder_inputs: A list of 2D Tensors [batch_size x input_size]. initial_state: 2D Tensor with shape [batch_size x cell.state_size]. - cell: core_rnn_cell.RNNCell defining the cell function and size. + cell: rnn_cell.RNNCell defining the cell function and size. loop_function: If not None, this function will be applied to the i-th output in order to generate the i+1-st input, and decoder_inputs will be ignored, except for the first element ("GO" symbol). This can be used for decoding, @@ -170,7 +170,7 @@ def basic_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 2D Tensors [batch_size x input_size]. decoder_inputs: A list of 2D Tensors [batch_size x input_size]. - cell: core_rnn_cell.RNNCell defining the cell function and size. + cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. dtype: The dtype of the initial state of the RNN cell (default: tf.float32). scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". @@ -183,7 +183,7 @@ def basic_rnn_seq2seq(encoder_inputs, """ with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): enc_cell = copy.deepcopy(cell) - _, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype) + _, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype) return rnn_decoder(decoder_inputs, enc_state, cell) @@ -202,7 +202,7 @@ def tied_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 2D Tensors [batch_size x input_size]. decoder_inputs: A list of 2D Tensors [batch_size x input_size]. - cell: core_rnn_cell.RNNCell defining the cell function and size. + cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. loop_function: If not None, this function will be applied to i-th output in order to generate i+1-th input, and decoder_inputs will be ignored, except for the first element ("GO" symbol), see rnn_decoder for details. @@ -219,7 +219,7 @@ def tied_rnn_seq2seq(encoder_inputs, """ with variable_scope.variable_scope("combined_tied_rnn_seq2seq"): scope = scope or "tied_rnn_seq2seq" - _, enc_state = core_rnn.static_rnn( + _, enc_state = rnn.static_rnn( cell, encoder_inputs, dtype=dtype, scope=scope) variable_scope.get_variable_scope().reuse_variables() return rnn_decoder( @@ -244,7 +244,7 @@ def embedding_rnn_decoder(decoder_inputs, Args: decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). initial_state: 2D Tensor [batch_size x cell.state_size]. - cell: core_rnn_cell.RNNCell defining the cell function. + cell: tf.nn.rnn_cell.RNNCell defining the cell function. num_symbols: Integer, how many symbols come into the embedding. embedding_size: Integer, the length of the embedding vector for each symbol. output_projection: None or a pair (W, B) of output projection weights and @@ -320,7 +320,7 @@ def embedding_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: core_rnn_cell.RNNCell defining the cell function and size. + cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols: Integer; number of symbols on the decoder side. embedding_size: Integer, the length of the embedding vector for each symbol. @@ -360,8 +360,7 @@ def embedding_rnn_seq2seq(encoder_inputs, encoder_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - _, encoder_state = core_rnn.static_rnn( - encoder_cell, encoder_inputs, dtype=dtype) + _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype) # Decoder. if output_projection is None: @@ -431,7 +430,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: core_rnn_cell.RNNCell defining the cell function and size. + cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. num_symbols: Integer; number of symbols for both encoder and decoder. embedding_size: Integer, the length of the embedding vector for each symbol. num_decoder_symbols: Integer; number of output symbols for decoder. If @@ -560,7 +559,7 @@ def attention_decoder(decoder_inputs, decoder_inputs: A list of 2D Tensors [batch_size x input_size]. initial_state: 2D Tensor [batch_size x cell.state_size]. attention_states: 3D Tensor [batch_size x attn_length x attn_size]. - cell: core_rnn_cell.RNNCell defining the cell function and size. + cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. output_size: Size of the output vectors; if None, we use cell.output_size. num_heads: Number of attention heads that read from attention_states. loop_function: If not None, this function will be applied to i-th output @@ -720,7 +719,7 @@ def embedding_attention_decoder(decoder_inputs, decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). initial_state: 2D Tensor [batch_size x cell.state_size]. attention_states: 3D Tensor [batch_size x attn_length x attn_size]. - cell: core_rnn_cell.RNNCell defining the cell function. + cell: tf.nn.rnn_cell.RNNCell defining the cell function. num_symbols: Integer, how many symbols come into the embedding. embedding_size: Integer, the length of the embedding vector for each symbol. num_heads: Number of attention heads that read from attention_states. @@ -814,7 +813,7 @@ def embedding_attention_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: core_rnn_cell.RNNCell defining the cell function and size. + cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols: Integer; number of symbols on the decoder side. embedding_size: Integer, the length of the embedding vector for each symbol. @@ -851,7 +850,7 @@ def embedding_attention_seq2seq(encoder_inputs, encoder_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - encoder_outputs, encoder_state = core_rnn.static_rnn( + encoder_outputs, encoder_state = rnn.static_rnn( encoder_cell, encoder_inputs, dtype=dtype) # First calculate a concatenation of encoder outputs to put attention on. @@ -937,9 +936,10 @@ def one2many_rnn_seq2seq(encoder_inputs, the corresponding decoder_inputs; each decoder_inputs is a list of 1D Tensors of shape [batch_size]; num_decoders is defined as len(decoder_inputs_dict). - enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size. + enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and + size. dec_cells_dict: A dictionary mapping encoder name (string) to an - instance of core_rnn_cell.RNNCell. + instance of tf.nn.rnn_cell.RNNCell. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an integer specifying number of symbols for the corresponding decoder; @@ -971,12 +971,12 @@ def one2many_rnn_seq2seq(encoder_inputs, outputs_dict = {} state_dict = {} - if not isinstance(enc_cell, core_rnn_cell.RNNCell): + if not isinstance(enc_cell, rnn_cell_impl.RNNCell): raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell)) if set(dec_cells_dict) != set(decoder_inputs_dict): raise ValueError("keys of dec_cells_dict != keys of decodre_inputs_dict") for dec_cell in dec_cells_dict.values(): - if not isinstance(dec_cell, core_rnn_cell.RNNCell): + if not isinstance(dec_cell, rnn_cell_impl.RNNCell): raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell)) with variable_scope.variable_scope( @@ -988,8 +988,7 @@ def one2many_rnn_seq2seq(encoder_inputs, enc_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) - _, encoder_state = core_rnn.static_rnn( - enc_cell, encoder_inputs, dtype=dtype) + _, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype) # Decoder. for name, decoder_inputs in decoder_inputs_dict.items(): @@ -1153,7 +1152,7 @@ def model_with_buckets(encoder_inputs, The seq2seq argument is a function that defines a sequence-to-sequence model, e.g., seq2seq = lambda x, y: basic_rnn_seq2seq( - x, y, core_rnn_cell.GRUCell(24)) + x, y, rnn_cell.GRUCell(24)) Args: encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input. |