diff options
author | 2016-12-10 11:37:06 -0800 | |
---|---|---|
committer | 2016-12-10 11:46:06 -0800 | |
commit | 13b0c97780b690e34f8b40057cd789080fb489fd (patch) | |
tree | f3a4aa87231c4f47a7170f796b3a70e3fffa9d2b /tensorflow/contrib/legacy_seq2seq/python | |
parent | d8f2a4b0e2548f1f2ea8ca44c134a2a2604af5c6 (diff) |
Update caller to move from tf.nn.rnn to (the identical) tf.contrib.rnn.static_rnn.
Change: 141658438
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq/python')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py index c908fbc655..7d550771b8 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py @@ -31,7 +31,7 @@ class Seq2SeqTest(tf.test.TestCase): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2])] * 2 - _, enc_state = tf.nn.rnn( + _, enc_state = tf.contrib.rnn.static_rnn( tf.contrib.rnn.GRUCell(2), inp, dtype=tf.float32) dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3 cell = tf.contrib.rnn.OutputProjectionWrapper( @@ -86,7 +86,7 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2])] * 2 cell = tf.contrib.rnn.BasicLSTMCell(2, state_is_tuple=True) - _, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) + _, enc_state = tf.contrib.rnn.static_rnn(cell, inp, dtype=tf.float32) dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)] dec, mem = tf.contrib.legacy_seq2seq.embedding_rnn_decoder( dec_inp, enc_state, cell, num_symbols=4, embedding_size=2) @@ -230,7 +230,8 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): cell = tf.contrib.rnn.GRUCell(2) inp = [tf.constant(0.5, shape=[2, 2])] * 2 - enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_state = tf.contrib.rnn.static_rnn( + cell, inp, dtype=tf.float32) attn_states = tf.concat_v2( [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1) dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3 @@ -250,7 +251,8 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): cell = tf.contrib.rnn.GRUCell(2) inp = [tf.constant(0.5, shape=[2, 2])] * 2 - enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_state = tf.contrib.rnn.static_rnn( + cell, inp, dtype=tf.float32) attn_states = tf.concat_v2( [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1) dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3 @@ -312,7 +314,8 @@ class Seq2SeqTest(tf.test.TestCase): cell = tf.contrib.rnn.MultiRNNCell(cells=[cell] * 2, state_is_tuple=True) inp = [tf.constant(0.5, shape=[2, 2])] * 2 - enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_state = tf.contrib.rnn.static_rnn( + cell, inp, dtype=tf.float32) attn_states = tf.concat_v2( [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1) dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3 @@ -339,7 +342,8 @@ class Seq2SeqTest(tf.test.TestCase): cell = tf.contrib.rnn.MultiRNNCell(cells=[cell] * 2, state_is_tuple=True) inp = tf.constant(0.5, shape=[2, 2, 2]) - enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_state = tf.contrib.rnn.static_rnn( + cell, inp, dtype=tf.float32) attn_states = tf.concat_v2( [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1) @@ -364,7 +368,8 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2])] * 2 cell = tf.contrib.rnn.GRUCell(2) - enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_state = tf.contrib.rnn.static_rnn( + cell, inp, dtype=tf.float32) attn_states = tf.concat_v2( [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1) dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)] |