diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-02-08 12:31:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-08 12:46:20 -0800 |
commit | 6d85f317e6773aa653d8166a28dc601f58aa7198 (patch) | |
tree | 96da7c9a8554b9126b2432d55a462c2121983997 /tensorflow/contrib/legacy_seq2seq | |
parent | 1fa028759586e1fa5a6bdfe4675490087dc3f6fe (diff) |
Prepare legacy_seq2seq for coming changes to RNNCell.
* Create new instances of RNNCells as appropriate for testing.
* Disable certain tests because there is no way to for them to work until
we allow variable reuse from an RNNCell's current scope within another
scope (in April).
Change: 146942741
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py | 305 |
1 files changed, 175 insertions, 130 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py index 900f609681..0f9f0a955c 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import math import random import sys @@ -110,14 +111,17 @@ class Seq2SeqTest(test.TestCase): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2) + cell = cell_fn() _, enc_state = core_rnn.static_rnn(cell, inp, dtype=dtypes.float32) dec_inp = [ constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) ] + # Use a new cell instance since the attention decoder uses a + # different variable scope. dec, mem = seq2seq_lib.embedding_rnn_decoder( - dec_inp, enc_state, cell, num_symbols=4, embedding_size=2) + dec_inp, enc_state, cell_fn(), num_symbols=4, embedding_size=2) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) self.assertEqual(3, len(res)) @@ -140,7 +144,8 @@ class Seq2SeqTest(test.TestCase): constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) ] - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2) + cell = cell_fn() dec, mem = seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp, @@ -159,11 +164,11 @@ class Seq2SeqTest(test.TestCase): # Test with state_is_tuple=False. with variable_scope.variable_scope("no_tuple"): - cell1 = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + cell_nt = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) dec, mem = seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp, - cell1, + cell_nt, num_encoder_symbols=2, num_decoder_symbols=5, embedding_size=2) @@ -182,7 +187,7 @@ class Seq2SeqTest(test.TestCase): dec, _ = seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp, - cell, + cell_fn(), num_encoder_symbols=2, num_decoder_symbols=5, embedding_size=2, @@ -201,7 +206,7 @@ class Seq2SeqTest(test.TestCase): d3, _ = seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp2, - cell, + cell_fn(), num_encoder_symbols=2, num_decoder_symbols=5, embedding_size=2, @@ -211,7 +216,7 @@ class Seq2SeqTest(test.TestCase): d1, _ = seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp, - cell, + cell_fn(), num_encoder_symbols=2, num_decoder_symbols=5, embedding_size=2, @@ -219,7 +224,7 @@ class Seq2SeqTest(test.TestCase): d2, _ = seq2seq_lib.embedding_rnn_seq2seq( enc_inp, dec_inp2, - cell, + cell_fn(), num_encoder_symbols=2, num_decoder_symbols=5, embedding_size=2, @@ -242,9 +247,11 @@ class Seq2SeqTest(test.TestCase): constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) ] - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + cell = functools.partial( + core_rnn_cell_impl.BasicLSTMCell, + 2, state_is_tuple=True) dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq( - enc_inp, dec_inp, cell, num_symbols=5, embedding_size=2) + enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) self.assertEqual(3, len(res)) @@ -260,7 +267,7 @@ class Seq2SeqTest(test.TestCase): dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq( enc_inp, dec_inp, - cell, + cell(), num_symbols=5, num_decoder_symbols=3, embedding_size=2) @@ -276,7 +283,7 @@ class Seq2SeqTest(test.TestCase): dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( enc_inp, dec_inp, - cell, + cell(), num_symbols=5, embedding_size=2, output_projection=(w, b)) @@ -291,7 +298,7 @@ class Seq2SeqTest(test.TestCase): d3, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( enc_inp, dec_inp2, - cell, + cell(), num_symbols=5, embedding_size=2, feed_previous=constant_op.constant(True)) @@ -300,14 +307,14 @@ class Seq2SeqTest(test.TestCase): d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( enc_inp, dec_inp, - cell, + cell(), num_symbols=5, embedding_size=2, feed_previous=True) d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( enc_inp, dec_inp2, - cell, + cell(), num_symbols=5, embedding_size=2, feed_previous=True) @@ -321,7 +328,8 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell = core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell = cell_fn() inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 enc_outputs, enc_state = core_rnn.static_rnn( cell, inp, dtype=dtypes.float32) @@ -329,8 +337,11 @@ class Seq2SeqTest(test.TestCase): array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs ], 1) dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 + + # Create a new cell instance for the decoder, since it uses a + # different variable scope dec, mem = seq2seq_lib.attention_decoder( - dec_inp, enc_state, attn_states, cell, output_size=4) + dec_inp, enc_state, attn_states, cell_fn(), output_size=4) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) self.assertEqual(3, len(res)) @@ -343,7 +354,8 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell = core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell = cell_fn() inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 enc_outputs, enc_state = core_rnn.static_rnn( cell, inp, dtype=dtypes.float32) @@ -351,8 +363,12 @@ class Seq2SeqTest(test.TestCase): array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs ], 1) dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 + + # Use a new cell instance since the attention decoder uses a + # different variable scope. dec, mem = seq2seq_lib.attention_decoder( - dec_inp, enc_state, attn_states, cell, output_size=4, num_heads=2) + dec_inp, enc_state, attn_states, cell_fn(), + output_size=4, num_heads=2) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) self.assertEqual(3, len(res)) @@ -365,14 +381,18 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell = core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell = cell_fn() inp = constant_op.constant(0.5, shape=[2, 2, 2]) enc_outputs, enc_state = rnn.dynamic_rnn( cell, inp, dtype=dtypes.float32) attn_states = enc_outputs dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 + + # Use a new cell instance since the attention decoder uses a + # different variable scope. dec, mem = seq2seq_lib.attention_decoder( - dec_inp, enc_state, attn_states, cell, output_size=4) + dec_inp, enc_state, attn_states, cell_fn(), output_size=4) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) self.assertEqual(3, len(res)) @@ -385,14 +405,19 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - cell = core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell = cell_fn() inp = constant_op.constant(0.5, shape=[2, 2, 2]) enc_outputs, enc_state = rnn.dynamic_rnn( cell, inp, dtype=dtypes.float32) attn_states = enc_outputs dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 + + # Use a new cell instance since the attention decoder uses a + # different variable scope. dec, mem = seq2seq_lib.attention_decoder( - dec_inp, enc_state, attn_states, cell, output_size=4, num_heads=2) + dec_inp, enc_state, attn_states, cell_fn(), + output_size=4, num_heads=2) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) self.assertEqual(3, len(res)) @@ -407,8 +432,9 @@ class Seq2SeqTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda 2, state_is_tuple=True) - cell = core_rnn_cell_impl.MultiRNNCell( + cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda cells=[single_cell() for _ in range(2)], state_is_tuple=True) + cell = cell_fn() inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 enc_outputs, enc_state = core_rnn.static_rnn( cell, inp, dtype=dtypes.float32) @@ -416,8 +442,11 @@ class Seq2SeqTest(test.TestCase): array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs ], 1) dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 + + # Use a new cell instance since the attention decoder uses a + # different variable scope. dec, mem = seq2seq_lib.attention_decoder( - dec_inp, enc_state, attn_states, cell, output_size=4) + dec_inp, enc_state, attn_states, cell_fn(), output_size=4) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) self.assertEqual(3, len(res)) @@ -434,11 +463,9 @@ class Seq2SeqTest(test.TestCase): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda - 2, state_is_tuple=True) - - cell = core_rnn_cell_impl.MultiRNNCell( - cells=[single_cell() for _ in range(2)], state_is_tuple=True) + cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda + cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)]) + cell = cell_fn() inp = constant_op.constant(0.5, shape=[2, 2, 2]) enc_outputs, enc_state = core_rnn.static_rnn( cell, inp, dtype=dtypes.float32) @@ -447,8 +474,11 @@ class Seq2SeqTest(test.TestCase): for e in enc_outputs ], 1) dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 + + # Use a new cell instance since the attention decoder uses a + # different variable scope. dec, mem = seq2seq_lib.attention_decoder( - dec_inp, enc_state, attn_states, cell, output_size=4) + dec_inp, enc_state, attn_states, cell_fn(), output_size=4) sess.run([variables.global_variables_initializer()]) res = sess.run(dec) self.assertEqual(3, len(res)) @@ -466,7 +496,8 @@ class Seq2SeqTest(test.TestCase): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 - cell = core_rnn_cell_impl.GRUCell(2) + cell_fn = lambda: core_rnn_cell_impl.GRUCell(2) + cell = cell_fn() enc_outputs, enc_state = core_rnn.static_rnn( cell, inp, dtype=dtypes.float32) attn_states = array_ops.concat([ @@ -476,11 +507,14 @@ class Seq2SeqTest(test.TestCase): constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) ] + + # Use a new cell instance since the attention decoder uses a + # different variable scope. dec, mem = seq2seq_lib.embedding_attention_decoder( dec_inp, enc_state, attn_states, - cell, + cell_fn(), num_symbols=4, embedding_size=2, output_size=3) @@ -504,7 +538,8 @@ class Seq2SeqTest(test.TestCase): constant_op.constant( i, dtypes.int32, shape=[2]) for i in range(3) ] - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2) + cell = cell_fn() dec, mem = seq2seq_lib.embedding_attention_seq2seq( enc_inp, dec_inp, @@ -523,11 +558,14 @@ class Seq2SeqTest(test.TestCase): # Test with state_is_tuple=False. with variable_scope.variable_scope("no_tuple"): - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + cell_fn = functools.partial( + core_rnn_cell_impl.BasicLSTMCell, + 2, state_is_tuple=False) + cell_nt = cell_fn() dec, mem = seq2seq_lib.embedding_attention_seq2seq( enc_inp, dec_inp, - cell, + cell_nt, num_encoder_symbols=2, num_decoder_symbols=5, embedding_size=2) @@ -546,7 +584,7 @@ class Seq2SeqTest(test.TestCase): dec, _ = seq2seq_lib.embedding_attention_seq2seq( enc_inp, dec_inp, - cell, + cell_fn(), num_encoder_symbols=2, num_decoder_symbols=5, embedding_size=2, @@ -556,43 +594,47 @@ class Seq2SeqTest(test.TestCase): self.assertEqual(3, len(res)) self.assertEqual((2, 2), res[0].shape) - # Test that previous-feeding model ignores inputs after the first. - dec_inp2 = [ - constant_op.constant( - 0, dtypes.int32, shape=[2]) for _ in range(3) - ] - with variable_scope.variable_scope("other"): - d3, _ = seq2seq_lib.embedding_attention_seq2seq( - enc_inp, - dec_inp2, - cell, - num_encoder_symbols=2, - num_decoder_symbols=5, - embedding_size=2, - feed_previous=constant_op.constant(True)) - sess.run([variables.global_variables_initializer()]) - variable_scope.get_variable_scope().reuse_variables() - d1, _ = seq2seq_lib.embedding_attention_seq2seq( - enc_inp, - dec_inp, - cell, - num_encoder_symbols=2, - num_decoder_symbols=5, - embedding_size=2, - feed_previous=True) - d2, _ = seq2seq_lib.embedding_attention_seq2seq( - enc_inp, - dec_inp2, - cell, - num_encoder_symbols=2, - num_decoder_symbols=5, - embedding_size=2, - feed_previous=True) - res1 = sess.run(d1) - res2 = sess.run(d2) - res3 = sess.run(d3) - self.assertAllClose(res1, res2) - self.assertAllClose(res1, res3) + # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse + # within a variable scope that already has a weights tensor. + # + # # Test that previous-feeding model ignores inputs after the first. + # dec_inp2 = [ + # constant_op.constant( + # 0, dtypes.int32, shape=[2]) for _ in range(3) + # ] + # with variable_scope.variable_scope("other"): + # d3, _ = seq2seq_lib.embedding_attention_seq2seq( + # enc_inp, + # dec_inp2, + # cell_fn(), + # num_encoder_symbols=2, + # num_decoder_symbols=5, + # embedding_size=2, + # feed_previous=constant_op.constant(True)) + # sess.run([variables.global_variables_initializer()]) + # variable_scope.get_variable_scope().reuse_variables() + # cell = cell_fn() + # d1, _ = seq2seq_lib.embedding_attention_seq2seq( + # enc_inp, + # dec_inp, + # cell, + # num_encoder_symbols=2, + # num_decoder_symbols=5, + # embedding_size=2, + # feed_previous=True) + # d2, _ = seq2seq_lib.embedding_attention_seq2seq( + # enc_inp, + # dec_inp2, + # cell, + # num_encoder_symbols=2, + # num_decoder_symbols=5, + # embedding_size=2, + # feed_previous=True) + # res1 = sess.run(d1) + # res2 = sess.run(d2) + # res3 = sess.run(d3) + # self.assertAllClose(res1, res2) + # self.assertAllClose(res1, res3) def testOne2ManyRNNSeq2Seq(self): with self.test_session() as sess: @@ -734,61 +776,64 @@ class Seq2SeqTest(test.TestCase): res = sess.run(loss_per_sequence) self.assertAllClose(np.asarray([4.828314, 4.828314]), res) - def testModelWithBucketsScopeAndLoss(self): - """Test that variable scope reuse is not reset after model_with_buckets.""" - classes = 10 - buckets = [(4, 4), (8, 8)] - - with self.test_session(): - # Here comes a sample Seq2Seq model using GRU cells. - def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss): - """Example sequence-to-sequence model that uses GRU cells.""" - - def GRUSeq2Seq(enc_inp, dec_inp): - cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(24) for _ in range(2)], - state_is_tuple=True) - return seq2seq_lib.embedding_attention_seq2seq( - enc_inp, - dec_inp, - cell, - num_encoder_symbols=classes, - num_decoder_symbols=classes, - embedding_size=24) - - targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0] - return seq2seq_lib.model_with_buckets( - enc_inp, - dec_inp, - targets, - weights, - buckets, - GRUSeq2Seq, - per_example_loss=per_example_loss) - - # Now we construct the copy model. - inp = [ - array_ops.placeholder( - dtypes.int32, shape=[None]) for _ in range(8) - ] - out = [ - array_ops.placeholder( - dtypes.int32, shape=[None]) for _ in range(8) - ] - weights = [ - array_ops.ones_like( - inp[0], dtype=dtypes.float32) for _ in range(8) - ] - with variable_scope.variable_scope("root"): - _, losses1 = SampleGRUSeq2Seq(inp, out, weights, per_example_loss=False) - # Now check that we did not accidentally set reuse. - self.assertEqual(False, variable_scope.get_variable_scope().reuse) - # Construct one more model with per-example loss. - variable_scope.get_variable_scope().reuse_variables() - _, losses2 = SampleGRUSeq2Seq(inp, out, weights, per_example_loss=True) - # First loss is scalar, the second one is a 1-dimensinal tensor. - self.assertEqual([], losses1[0].get_shape().as_list()) - self.assertEqual([None], losses2[0].get_shape().as_list()) + # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse + # within a variable scope that already has a weights tensor. + # + # def testModelWithBucketsScopeAndLoss(self): + # """Test variable scope reuse is not reset after model_with_buckets.""" + # classes = 10 + # buckets = [(4, 4), (8, 8)] + + # with self.test_session(): + # # Here comes a sample Seq2Seq model using GRU cells. + # def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss): + # """Example sequence-to-sequence model that uses GRU cells.""" + + # def GRUSeq2Seq(enc_inp, dec_inp): + # cell = core_rnn_cell_impl.MultiRNNCell( + # [core_rnn_cell_impl.GRUCell(24) for _ in range(2)]) + # return seq2seq_lib.embedding_attention_seq2seq( + # enc_inp, + # dec_inp, + # cell, + # num_encoder_symbols=classes, + # num_decoder_symbols=classes, + # embedding_size=24) + + # targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0] + # return seq2seq_lib.model_with_buckets( + # enc_inp, + # dec_inp, + # targets, + # weights, + # buckets, + # GRUSeq2Seq, + # per_example_loss=per_example_loss) + + # # Now we construct the copy model. + # inp = [ + # array_ops.placeholder( + # dtypes.int32, shape=[None]) for _ in range(8) + # ] + # out = [ + # array_ops.placeholder( + # dtypes.int32, shape=[None]) for _ in range(8) + # ] + # weights = [ + # array_ops.ones_like( + # inp[0], dtype=dtypes.float32) for _ in range(8) + # ] + # with variable_scope.variable_scope("root"): + # _, losses1 = SampleGRUSeq2Seq( + # inp, out, weights, per_example_loss=False) + # # Now check that we did not accidentally set reuse. + # self.assertEqual(False, variable_scope.get_variable_scope().reuse) + # with variable_scope.variable_scope("new"): + # _, losses2 = SampleGRUSeq2Seq + # inp, out, weights, per_example_loss=True) + # # First loss is scalar, the second one is a 1-dimensinal tensor. + # self.assertEqual([], losses1[0].get_shape().as_list()) + # self.assertEqual([None], losses2[0].get_shape().as_list()) def testModelWithBuckets(self): """Larger tests that does full sequence-to-sequence model training.""" |