author     Eugene Brevdo <ebrevdo@google.com>  2017-02-10 11:45:07 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-02-10 12:08:37 -0800
commit     a83290927bc2c7d64560d442d8aaec2d1cf155ab (patch)
tree       6bbd916fd950ee464e39bfbaef31669b92f482f3 /tensorflow/contrib/legacy_seq2seq
parent     b4e6d2959810dc2be17e578bc68449c253407b1c (diff)
More seq2seq tweaks and test updates as RNNCells become more like layers.
Change: 147180372
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py | 116
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py               |  36
2 files changed, 89 insertions(+), 63 deletions(-)
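
The test updates below all follow one pattern: now that RNNCells create and own their variables the way layers do, a single cell instance is no longer shared across separate seq2seq constructions via reuse_variables(); instead, each construction builds a fresh cell inside its own variable scope. A minimal sketch of that pattern, assuming the contrib-era module aliases used in seq2seq_test.py (seq2seq_lib, core_rnn_cell_impl, variable_scope); the helpers cell_fn and build_seq2seq are illustrative only, not part of this change:

from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope


def cell_fn():
  # A fresh cell per seq2seq construction: the cell creates and owns its
  # variables the first time it is called, like a layer.
  return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)


def build_seq2seq(scope_name, feed_previous):
  # Each construction gets its own variable scope and its own cell instance,
  # replacing the old get_variable_scope().reuse_variables() pattern.
  enc_inp = [constant_op.constant(1, dtypes.int32, shape=[2]) for _ in range(2)]
  dec_inp = [constant_op.constant(i, dtypes.int32, shape=[2]) for i in range(3)]
  with variable_scope.variable_scope(scope_name):
    return seq2seq_lib.embedding_rnn_seq2seq(
        enc_inp,
        dec_inp,
        cell_fn(),
        num_encoder_symbols=2,
        num_decoder_symbols=5,
        embedding_size=2,
        feed_previous=feed_previous)
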
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 0f9f0a955c..5d77593619 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -211,24 +211,25 @@ class Seq2SeqTest(test.TestCase):
num_decoder_symbols=5,
embedding_size=2,
feed_previous=constant_op.constant(True))
+ with variable_scope.variable_scope("other_2"):
+ d1, _ = seq2seq_lib.embedding_rnn_seq2seq(
+ enc_inp,
+ dec_inp,
+ cell_fn(),
+ num_encoder_symbols=2,
+ num_decoder_symbols=5,
+ embedding_size=2,
+ feed_previous=True)
+ with variable_scope.variable_scope("other_3"):
+ d2, _ = seq2seq_lib.embedding_rnn_seq2seq(
+ enc_inp,
+ dec_inp2,
+ cell_fn(),
+ num_encoder_symbols=2,
+ num_decoder_symbols=5,
+ embedding_size=2,
+ feed_previous=True)
sess.run([variables.global_variables_initializer()])
- variable_scope.get_variable_scope().reuse_variables()
- d1, _ = seq2seq_lib.embedding_rnn_seq2seq(
- enc_inp,
- dec_inp,
- cell_fn(),
- num_encoder_symbols=2,
- num_decoder_symbols=5,
- embedding_size=2,
- feed_previous=True)
- d2, _ = seq2seq_lib.embedding_rnn_seq2seq(
- enc_inp,
- dec_inp2,
- cell_fn(),
- num_encoder_symbols=2,
- num_decoder_symbols=5,
- embedding_size=2,
- feed_previous=True)
res1 = sess.run(d1)
res2 = sess.run(d2)
res3 = sess.run(d3)
@@ -302,22 +303,23 @@ class Seq2SeqTest(test.TestCase):
num_symbols=5,
embedding_size=2,
feed_previous=constant_op.constant(True))
+ with variable_scope.variable_scope("other_2"):
+ d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
+ enc_inp,
+ dec_inp,
+ cell(),
+ num_symbols=5,
+ embedding_size=2,
+ feed_previous=True)
+ with variable_scope.variable_scope("other_3"):
+ d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
+ enc_inp,
+ dec_inp2,
+ cell(),
+ num_symbols=5,
+ embedding_size=2,
+ feed_previous=True)
sess.run([variables.global_variables_initializer()])
- variable_scope.get_variable_scope().reuse_variables()
- d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
- enc_inp,
- dec_inp,
- cell(),
- num_symbols=5,
- embedding_size=2,
- feed_previous=True)
- d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
- enc_inp,
- dec_inp2,
- cell(),
- num_symbols=5,
- embedding_size=2,
- feed_previous=True)
res1 = sess.run(d1)
res2 = sess.run(d2)
res3 = sess.run(d3)
@@ -654,9 +656,15 @@ class Seq2SeqTest(test.TestCase):
i, dtypes.int32, shape=[2]) for i in range(4)
]
dec_symbols_dict = {"0": 5, "1": 6}
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ def EncCellFn():
+ return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ def DecCellsFn():
+ return dict(
+ (k, core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True))
+ for k in dec_symbols_dict)
outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq(
- enc_inp, dec_inp_dict, cell, 2, dec_symbols_dict, embedding_size=2))
+ enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(),
+ 2, dec_symbols_dict, embedding_size=2))
sess.run([variables.global_variables_initializer()])
res = sess.run(outputs_dict["0"])
@@ -688,29 +696,33 @@ class Seq2SeqTest(test.TestCase):
outputs_dict3, _ = seq2seq_lib.one2many_rnn_seq2seq(
enc_inp,
dec_inp_dict2,
- cell,
+ EncCellFn(),
+ DecCellsFn(),
2,
dec_symbols_dict,
embedding_size=2,
feed_previous=constant_op.constant(True))
+ with variable_scope.variable_scope("other_2"):
+ outputs_dict1, _ = seq2seq_lib.one2many_rnn_seq2seq(
+ enc_inp,
+ dec_inp_dict,
+ EncCellFn(),
+ DecCellsFn(),
+ 2,
+ dec_symbols_dict,
+ embedding_size=2,
+ feed_previous=True)
+ with variable_scope.variable_scope("other_3"):
+ outputs_dict2, _ = seq2seq_lib.one2many_rnn_seq2seq(
+ enc_inp,
+ dec_inp_dict2,
+ EncCellFn(),
+ DecCellsFn(),
+ 2,
+ dec_symbols_dict,
+ embedding_size=2,
+ feed_previous=True)
sess.run([variables.global_variables_initializer()])
- variable_scope.get_variable_scope().reuse_variables()
- outputs_dict1, _ = seq2seq_lib.one2many_rnn_seq2seq(
- enc_inp,
- dec_inp_dict,
- cell,
- 2,
- dec_symbols_dict,
- embedding_size=2,
- feed_previous=True)
- outputs_dict2, _ = seq2seq_lib.one2many_rnn_seq2seq(
- enc_inp,
- dec_inp_dict2,
- cell,
- 2,
- dec_symbols_dict,
- embedding_size=2,
- feed_previous=True)
res1 = sess.run(outputs_dict1["0"])
res2 = sess.run(outputs_dict2["0"])
res3 = sess.run(outputs_dict3["0"])
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index 0d6eac33de..8608054deb 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -917,7 +917,8 @@ def embedding_attention_seq2seq(encoder_inputs,
def one2many_rnn_seq2seq(encoder_inputs,
decoder_inputs_dict,
- cell,
+ enc_cell,
+ dec_cells_dict,
num_encoder_symbols,
num_decoder_symbols_dict,
embedding_size,
@@ -936,7 +937,9 @@ def one2many_rnn_seq2seq(encoder_inputs,
the corresponding decoder_inputs; each decoder_inputs is a list of 1D
Tensors of shape [batch_size]; num_decoders is defined as
len(decoder_inputs_dict).
- cell: core_rnn_cell.RNNCell defining the cell function and size.
+ enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size.
+ dec_cells_dict: A dictionary mapping decoder name (string) to an
+ instance of core_rnn_cell.RNNCell.
num_encoder_symbols: Integer; number of symbols on the encoder side.
num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
integer specifying number of symbols for the corresponding decoder;
@@ -960,37 +963,48 @@ def one2many_rnn_seq2seq(encoder_inputs,
state_dict: A mapping from decoder name (string) to the final state of the
corresponding decoder RNN; it is a 2D Tensor of shape
[batch_size x cell.state_size].
+
+ Raises:
+ TypeError: if enc_cell or any cell in dec_cells_dict is not an RNNCell.
+ ValueError: if the keys of dec_cells_dict and decoder_inputs_dict differ.
"""
outputs_dict = {}
state_dict = {}
+ if not isinstance(enc_cell, core_rnn_cell.RNNCell):
+ raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
+ if set(dec_cells_dict) != set(decoder_inputs_dict):
+ raise ValueError("keys of dec_cells_dict != keys of decoder_inputs_dict")
+ for dec_cell in dec_cells_dict.values():
+ if not isinstance(dec_cell, core_rnn_cell.RNNCell):
+ raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))
+
with variable_scope.variable_scope(
scope or "one2many_rnn_seq2seq", dtype=dtype) as scope:
dtype = scope.dtype
# Encoder.
- encoder_cell = copy.deepcopy(cell)
- encoder_cell = core_rnn_cell.EmbeddingWrapper(
- encoder_cell,
+ enc_cell = core_rnn_cell.EmbeddingWrapper(
+ enc_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
_, encoder_state = core_rnn.static_rnn(
- encoder_cell, encoder_inputs, dtype=dtype)
+ enc_cell, encoder_inputs, dtype=dtype)
# Decoder.
for name, decoder_inputs in decoder_inputs_dict.items():
num_decoder_symbols = num_decoder_symbols_dict[name]
+ dec_cell = dec_cells_dict[name]
with variable_scope.variable_scope("one2many_decoder_" + str(
name)) as scope:
- decoder_cell = copy.deepcopy(cell)
- decoder_cell = core_rnn_cell.OutputProjectionWrapper(
- decoder_cell, num_decoder_symbols)
+ dec_cell = core_rnn_cell.OutputProjectionWrapper(
+ dec_cell, num_decoder_symbols)
if isinstance(feed_previous, bool):
outputs, state = embedding_rnn_decoder(
decoder_inputs,
encoder_state,
- decoder_cell,
+ dec_cell,
num_decoder_symbols,
embedding_size,
feed_previous=feed_previous)
@@ -1005,7 +1019,7 @@ def one2many_rnn_seq2seq(encoder_inputs,
outputs, state = embedding_rnn_decoder(
decoder_inputs,
encoder_state,
- decoder_cell,
+ dec_cell,
num_decoder_symbols,
embedding_size,
feed_previous=feed_previous)
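
With this change, callers of one2many_rnn_seq2seq pass a dedicated encoder cell plus one decoder cell per output head instead of a single cell that was deep-copied internally. A minimal calling sketch of the new signature, assuming the same contrib-era imports as above and the small shapes used in the test; this is an illustration, not code from this commit:

from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes

num_decoder_symbols_dict = {"0": 5, "1": 6}
enc_inp = [constant_op.constant(1, dtypes.int32, shape=[2]) for _ in range(2)]
dec_inp_dict = {
    "0": [constant_op.constant(i, dtypes.int32, shape=[2]) for i in range(3)],
    "1": [constant_op.constant(i, dtypes.int32, shape=[2]) for i in range(4)],
}

# One fresh cell for the encoder and one per decoder; the keys of
# dec_cells_dict must match the keys of decoder_inputs_dict, otherwise
# one2many_rnn_seq2seq raises a ValueError.
enc_cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
dec_cells_dict = {
    name: core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
    for name in num_decoder_symbols_dict
}

outputs_dict, state_dict = seq2seq_lib.one2many_rnn_seq2seq(
    enc_inp,
    dec_inp_dict,
    enc_cell,
    dec_cells_dict,
    num_encoder_symbols=2,
    num_decoder_symbols_dict=num_decoder_symbols_dict,
    embedding_size=2)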