aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/legacy_seq2seq
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-01-19 22:13:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-19 22:27:57 -0800
commita20cfa5494f4491ba38e460653191a3987af771b (patch)
tree71bb33f5b21bf4fc418bf8dd20123db63597c790 /tensorflow/contrib/legacy_seq2seq
parentd3e3314ebe37d4ec2fe005913cfd8fedd51d92cc (diff)
Force new instance creation in MultiRNNCell
Change: 145049923
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 041cc6bf82..900f609681 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -405,9 +405,10 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda
+ 2, state_is_tuple=True)
cell = core_rnn_cell_impl.MultiRNNCell(
- cells=[cell] * 2, state_is_tuple=True)
+ cells=[single_cell() for _ in range(2)], state_is_tuple=True)
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
enc_outputs, enc_state = core_rnn.static_rnn(
cell, inp, dtype=dtypes.float32)
@@ -433,9 +434,11 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda
+ 2, state_is_tuple=True)
+
cell = core_rnn_cell_impl.MultiRNNCell(
- cells=[cell] * 2, state_is_tuple=True)
+ cells=[single_cell() for _ in range(2)], state_is_tuple=True)
inp = constant_op.constant(0.5, shape=[2, 2, 2])
enc_outputs, enc_state = core_rnn.static_rnn(
cell, inp, dtype=dtypes.float32)
@@ -743,7 +746,8 @@ class Seq2SeqTest(test.TestCase):
def GRUSeq2Seq(enc_inp, dec_inp):
cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(24)] * 2, state_is_tuple=True)
+ [core_rnn_cell_impl.GRUCell(24) for _ in range(2)],
+ state_is_tuple=True)
return seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
@@ -808,7 +812,8 @@ class Seq2SeqTest(test.TestCase):
def GRUSeq2Seq(enc_inp, dec_inp):
cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(24)] * 2, state_is_tuple=True)
+ [core_rnn_cell_impl.GRUCell(24) for _ in range(2)],
+ state_is_tuple=True)
return seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,