diff options
author | 2017-01-19 22:13:55 -0800 | |
---|---|---|
committer | 2017-01-19 22:27:57 -0800 | |
commit | a20cfa5494f4491ba38e460653191a3987af771b (patch) | |
tree | 71bb33f5b21bf4fc418bf8dd20123db63597c790 /tensorflow/contrib/legacy_seq2seq | |
parent | d3e3314ebe37d4ec2fe005913cfd8fedd51d92cc (diff) |
Force new instance creation in MultiRNNCell
Change: 145049923
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py index 041cc6bf82..900f609681 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py @@ -405,9 +405,10 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda + 2, state_is_tuple=True) cell = core_rnn_cell_impl.MultiRNNCell( - cells=[cell] * 2, state_is_tuple=True) + cells=[single_cell() for _ in range(2)], state_is_tuple=True) inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 enc_outputs, enc_state = core_rnn.static_rnn( cell, inp, dtype=dtypes.float32) @@ -433,9 +434,11 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda + 2, state_is_tuple=True) + cell = core_rnn_cell_impl.MultiRNNCell( - cells=[cell] * 2, state_is_tuple=True) + cells=[single_cell() for _ in range(2)], state_is_tuple=True) inp = constant_op.constant(0.5, shape=[2, 2, 2]) enc_outputs, enc_state = core_rnn.static_rnn( cell, inp, dtype=dtypes.float32) @@ -743,7 +746,8 @@ class Seq2SeqTest(test.TestCase): def GRUSeq2Seq(enc_inp, dec_inp): cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(24)] * 2, state_is_tuple=True) + [core_rnn_cell_impl.GRUCell(24) for _ in range(2)], + state_is_tuple=True) return seq2seq_lib.embedding_attention_seq2seq( enc_inp, dec_inp, @@ -808,7 +812,8 @@ class Seq2SeqTest(test.TestCase): def GRUSeq2Seq(enc_inp, dec_inp): cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(24)] * 2, state_is_tuple=True) + [core_rnn_cell_impl.GRUCell(24) for _ in range(2)], + state_is_tuple=True) return seq2seq_lib.embedding_attention_seq2seq( enc_inp, dec_inp, |