diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/seq2seq_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/seq2seq_test.py | 99 |
1 file changed, 99 insertions, 0 deletions
def testModelWithBooleanFeedPrevious(self):
  """Test the model behavior when feed_previous is True.

  For example, the following two cases have the same effect:
    - Train `embedding_rnn_seq2seq` with `feed_previous=True`, which contains
      a `embedding_rnn_decoder` with `feed_previous=True` and
      `update_embedding_for_previous=True`. The decoder is fed with "<Go>"
      and outputs "A, B, C".
    - Train `embedding_rnn_seq2seq` with `feed_previous=False`. The decoder
      is fed with "<Go>, A, B".
  """
  # Small, fixed problem sizes keep the graphs tiny and the test fast.
  num_encoder_symbols = 3
  num_decoder_symbols = 5
  batch_size = 2
  num_enc_timesteps = 2
  num_dec_timesteps = 3

  def TestModel(seq2seq):
    # Runs the equivalence check for one seq2seq model constructor.
    with self.test_session(graph=tf.Graph()) as sess:
      # Fix all RNG sources so both models are built/trained deterministically.
      tf.set_random_seed(111)
      random.seed(111)
      np.random.seed(111)

      # Encoder inputs are constants; decoder inputs for feed_previous=True
      # are constants, while those for feed_previous=False are placeholders
      # that will be fed with the symbols the first model actually generated.
      enc_inp = [tf.constant(i + 1, tf.int32, shape=[batch_size])
                 for i in range(num_enc_timesteps)]
      dec_inp_fp_true = [tf.constant(i, tf.int32, shape=[batch_size])
                         for i in range(num_dec_timesteps)]
      dec_inp_holder_fp_false = [tf.placeholder(tf.int32, shape=[batch_size])
                                 for _ in range(num_dec_timesteps)]
      targets = [tf.constant(i + 1, tf.int32, shape=[batch_size])
                 for i in range(num_dec_timesteps)]
      # Uniform per-step loss weights of 1.0.
      weights = [tf.constant(1.0, shape=[batch_size])
                 for i in range(num_dec_timesteps)]

      def ForwardBackward(enc_inp, dec_inp, feed_previous):
        # Builds one model under its own variable scope and returns its
        # decoder outputs, training op, and the variables it created.
        scope_name = "fp_{}".format(feed_previous)
        with tf.variable_scope(scope_name):
          dec_op, _ = seq2seq(enc_inp, dec_inp, feed_previous=feed_previous)
          net_variables = tf.get_collection(tf.GraphKeys.VARIABLES,
                                            scope_name)
        optimizer = tf.train.AdamOptimizer(0.03, epsilon=1e-5)
        update_op = optimizer.minimize(
            tf.nn.seq2seq.sequence_loss(dec_op, targets, weights),
            var_list=net_variables)
        return dec_op, update_op, net_variables

      dec_op_fp_true, update_fp_true, variables_fp_true = ForwardBackward(
          enc_inp, dec_inp_fp_true, feed_previous=True)
      dec_op_fp_false, update_fp_false, variables_fp_false = ForwardBackward(
          enc_inp, dec_inp_holder_fp_false, feed_previous=False)

      sess.run(tf.initialize_all_variables())

      # We only check consistencies between the variables existing in both
      # the models with True and False feed_previous. Variables created by
      # the loop_function in the model with True feed_previous are ignored.
      # Variables are matched by their name with the scope prefix stripped.
      v_false_name_dict = {v.name.split('/', 1)[-1]: v
                           for v in variables_fp_false}
      matched_variables = [(v, v_false_name_dict[v.name.split('/', 1)[-1]])
                           for v in variables_fp_true]
      # Copy the feed_previous=True model's weights into the other model so
      # both start the training step from identical parameters.
      for v_true, v_false in matched_variables:
        sess.run(tf.assign(v_false, v_true))

      # Take the symbols generated by the decoder with feed_previous=True as
      # the true input symbols for the decoder with feed_previous=False.
      dec_fp_true = sess.run(dec_op_fp_true)
      output_symbols_fp_true = np.argmax(dec_fp_true, axis=2)
      # Shift by one: step t of the False model consumes the symbol the True
      # model emitted at step t-1 (the first input is the "<Go>" constant).
      dec_inp_fp_false = np.vstack((dec_inp_fp_true[0].eval(),
                                    output_symbols_fp_true[:-1]))
      sess.run(update_fp_true)
      sess.run(update_fp_false,
               {holder: inp for holder, inp in zip(dec_inp_holder_fp_false,
                                                   dec_inp_fp_false)})

      # After one identical gradient step, the matched variables of the two
      # models must agree (up to float tolerance).
      for v_true, v_false in matched_variables:
        self.assertAllClose(v_true.eval(), v_false.eval())

  def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
    # embedding_rnn_seq2seq with a 2-unit LSTM cell.
    cell = tf.nn.rnn_cell.BasicLSTMCell(2)
    return tf.nn.seq2seq.embedding_rnn_seq2seq(
        enc_inp, dec_inp, cell, num_encoder_symbols,
        num_decoder_symbols, feed_previous=feed_previous)

  def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
    # embedding_tied_rnn_seq2seq shares encoder/decoder embeddings, so it
    # takes only num_decoder_symbols.
    cell = tf.nn.rnn_cell.BasicLSTMCell(2)
    return tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
        enc_inp, dec_inp, cell, num_decoder_symbols,
        feed_previous=feed_previous)

  def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
    # embedding_attention_seq2seq with a 2-unit LSTM cell.
    cell = tf.nn.rnn_cell.BasicLSTMCell(2)
    return tf.nn.seq2seq.embedding_attention_seq2seq(
        enc_inp, dec_inp, cell, num_encoder_symbols,
        num_decoder_symbols, feed_previous=feed_previous)

  # Run the same equivalence check across all three seq2seq variants.
  for model in (EmbeddingRNNSeq2SeqF, EmbeddingTiedRNNSeq2Seq,
                EmbeddingAttentionSeq2Seq):
    TestModel(model)


if __name__ == "__main__":
  tf.test.main()