path: root/tensorflow/python/kernel_tests/seq2seq_test.py
Diffstat (limited to 'tensorflow/python/kernel_tests/seq2seq_test.py')
-rw-r--r--  tensorflow/python/kernel_tests/seq2seq_test.py  99
1 file changed, 99 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/kernel_tests/seq2seq_test.py b/tensorflow/python/kernel_tests/seq2seq_test.py
index 77ff0571b7..a6f017f22f 100644
--- a/tensorflow/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/python/kernel_tests/seq2seq_test.py
@@ -495,6 +495,105 @@ class Seq2SeqTest(tf.test.TestCase):
         if len(perplexities[bucket]) > 1:  # Assert that perplexity went down.
           self.assertLess(perplexities[bucket][-1], perplexities[bucket][0])
+  def testModelWithBooleanFeedPrevious(self):
+    """Test the model behavior when feed_previous is True.
+
+    For example, the following two cases have the same effect:
+      - Train `embedding_rnn_seq2seq` with `feed_previous=True`, which
+        contains an `embedding_rnn_decoder` with `feed_previous=True` and
+        `update_embedding_for_previous=True`. The decoder is fed with "<Go>"
+        and outputs "A, B, C".
+      - Train `embedding_rnn_seq2seq` with `feed_previous=False`. The decoder
+        is fed with "<Go>, A, B".
+    """
+    num_encoder_symbols = 3
+    num_decoder_symbols = 5
+    batch_size = 2
+    num_enc_timesteps = 2
+    num_dec_timesteps = 3
+
+    def TestModel(seq2seq):
+      with self.test_session(graph=tf.Graph()) as sess:
+        tf.set_random_seed(111)
+        random.seed(111)
+        np.random.seed(111)
+
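+        # All inputs and targets are fixed tensors. The feed_previous=False
+        # decoder reads its inputs through placeholders so they can later be
+        # fed the symbols that the feed_previous=True decoder produces.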
+        enc_inp = [tf.constant(i + 1, tf.int32, shape=[batch_size])
+                   for i in range(num_enc_timesteps)]
+        dec_inp_fp_true = [tf.constant(i, tf.int32, shape=[batch_size])
+                           for i in range(num_dec_timesteps)]
+        dec_inp_holder_fp_false = [tf.placeholder(tf.int32,
+                                                  shape=[batch_size])
+                                   for _ in range(num_dec_timesteps)]
+        targets = [tf.constant(i + 1, tf.int32, shape=[batch_size])
+                   for i in range(num_dec_timesteps)]
+        weights = [tf.constant(1.0, shape=[batch_size])
+                   for _ in range(num_dec_timesteps)]
+
+        def ForwardBackward(enc_inp, dec_inp, feed_previous):
+          scope_name = "fp_{}".format(feed_previous)
+          with tf.variable_scope(scope_name):
+            dec_op, _ = seq2seq(enc_inp, dec_inp, feed_previous=feed_previous)
+            net_variables = tf.get_collection(tf.GraphKeys.VARIABLES,
+                                              scope_name)
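+          # Minimize the loss over this graph's own variables only, so that
+          # the two models can be trained side by side without interfering.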
+          optimizer = tf.train.AdamOptimizer(0.03, epsilon=1e-5)
+          update_op = optimizer.minimize(
+              tf.nn.seq2seq.sequence_loss(dec_op, targets, weights),
+              var_list=net_variables)
+          return dec_op, update_op, net_variables
+
+        dec_op_fp_true, update_fp_true, variables_fp_true = ForwardBackward(
+            enc_inp, dec_inp_fp_true, feed_previous=True)
+        dec_op_fp_false, update_fp_false, variables_fp_false = ForwardBackward(
+            enc_inp, dec_inp_holder_fp_false, feed_previous=False)
+
+        sess.run(tf.initialize_all_variables())
+
+        # We only check consistency between variables that exist in both
+        # models (feed_previous=True and feed_previous=False). Variables
+        # created by the loop_function in the feed_previous=True model are
+        # ignored.
+        v_false_name_dict = {v.name.split('/', 1)[-1]: v
+                             for v in variables_fp_false}
+        matched_variables = [(v, v_false_name_dict[v.name.split('/', 1)[-1]])
+                             for v in variables_fp_true]
+        for v_true, v_false in matched_variables:
+          sess.run(tf.assign(v_false, v_true))
+
+        # Take the symbols generated by the feed_previous=True decoder as
+        # the true input symbols for the feed_previous=False decoder.
+        dec_fp_true = sess.run(dec_op_fp_true)
+        output_symbols_fp_true = np.argmax(dec_fp_true, axis=2)
+        dec_inp_fp_false = np.vstack((dec_inp_fp_true[0].eval(),
+                                      output_symbols_fp_true[:-1]))
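+        # Apply one identical training step to each model. Starting from
+        # matched variables and equivalent inputs, the updated variables
+        # should match as well.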
+        sess.run(update_fp_true)
+        sess.run(update_fp_false,
+                 {holder: inp for holder, inp in zip(dec_inp_holder_fp_false,
+                                                     dec_inp_fp_false)})
+
+        for v_true, v_false in matched_variables:
+          self.assertAllClose(v_true.eval(), v_false.eval())
+
+    def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
+      cell = tf.nn.rnn_cell.BasicLSTMCell(2)
+      return tf.nn.seq2seq.embedding_rnn_seq2seq(
+          enc_inp, dec_inp, cell, num_encoder_symbols,
+          num_decoder_symbols, feed_previous=feed_previous)
+
+    def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
+      cell = tf.nn.rnn_cell.BasicLSTMCell(2)
+      return tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
+          enc_inp, dec_inp, cell, num_decoder_symbols,
+          feed_previous=feed_previous)
+
+    def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
+      cell = tf.nn.rnn_cell.BasicLSTMCell(2)
+      return tf.nn.seq2seq.embedding_attention_seq2seq(
+          enc_inp, dec_inp, cell, num_encoder_symbols,
+          num_decoder_symbols, feed_previous=feed_previous)
+
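+    # Exercise the consistency check on each embedding-based seq2seq variant.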
+    for model in (EmbeddingRNNSeq2SeqF, EmbeddingTiedRNNSeq2Seq,
+                  EmbeddingAttentionSeq2Seq):
+      TestModel(model)
+
 if __name__ == "__main__":
   tf.test.main()
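
For readers skimming the diff: the equivalence the docstring relies on can be
sketched outside TensorFlow. The snippet below is a minimal, hand-rolled
illustration, not TensorFlow API; `decode` and the toy `step` model are
hypothetical stand-ins for one greedy decoder step.

    def decode(step, go_symbol, targets, feed_previous):
        """Greedily decode len(targets) steps with a toy one-step model."""
        inputs, outputs = [go_symbol], []
        for t in range(len(targets)):
            out = step(inputs[-1])  # one decoder step on the latest input
            outputs.append(out)
            # feed_previous=True feeds the decoder's own output back in;
            # feed_previous=False feeds the ground-truth target instead.
            inputs.append(out if feed_previous else targets[t])
        return outputs

    step = lambda x: (x + 1) % 5  # toy "model": emit the next symbol

    # Let the targets be exactly what the model emits on its own ("A, B, C");
    # feeding ground truth ("<Go>, A, B") then yields the same outputs.
    targets = decode(step, 0, [None] * 3, feed_previous=True)
    assert decode(step, 0, targets, feed_previous=False) == targets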