aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/legacy_seq2seq/python
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2016-12-10 11:37:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-10 11:46:06 -0800
commit13b0c97780b690e34f8b40057cd789080fb489fd (patch)
treef3a4aa87231c4f47a7170f796b3a70e3fffa9d2b /tensorflow/contrib/legacy_seq2seq/python
parentd8f2a4b0e2548f1f2ea8ca44c134a2a2604af5c6 (diff)
Update caller to move from tf.nn.rnn to (the identical) tf.contrib.rnn.static_rnn.
Change: 141658438
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq/python')
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index c908fbc655..7d550771b8 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -31,7 +31,7 @@ class Seq2SeqTest(tf.test.TestCase):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2])] * 2
- _, enc_state = tf.nn.rnn(
+ _, enc_state = tf.contrib.rnn.static_rnn(
tf.contrib.rnn.GRUCell(2), inp, dtype=tf.float32)
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
cell = tf.contrib.rnn.OutputProjectionWrapper(
@@ -86,7 +86,7 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2])] * 2
cell = tf.contrib.rnn.BasicLSTMCell(2, state_is_tuple=True)
- _, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
+ _, enc_state = tf.contrib.rnn.static_rnn(cell, inp, dtype=tf.float32)
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]
dec, mem = tf.contrib.legacy_seq2seq.embedding_rnn_decoder(
dec_inp, enc_state, cell, num_symbols=4, embedding_size=2)
@@ -230,7 +230,8 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
cell = tf.contrib.rnn.GRUCell(2)
inp = [tf.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
+ enc_outputs, enc_state = tf.contrib.rnn.static_rnn(
+ cell, inp, dtype=tf.float32)
attn_states = tf.concat_v2(
[tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1)
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
@@ -250,7 +251,8 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
cell = tf.contrib.rnn.GRUCell(2)
inp = [tf.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
+ enc_outputs, enc_state = tf.contrib.rnn.static_rnn(
+ cell, inp, dtype=tf.float32)
attn_states = tf.concat_v2(
[tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1)
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
@@ -312,7 +314,8 @@ class Seq2SeqTest(tf.test.TestCase):
cell = tf.contrib.rnn.MultiRNNCell(cells=[cell] * 2,
state_is_tuple=True)
inp = [tf.constant(0.5, shape=[2, 2])] * 2
- enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
+ enc_outputs, enc_state = tf.contrib.rnn.static_rnn(
+ cell, inp, dtype=tf.float32)
attn_states = tf.concat_v2(
[tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1)
dec_inp = [tf.constant(0.4, shape=[2, 2])] * 3
@@ -339,7 +342,8 @@ class Seq2SeqTest(tf.test.TestCase):
cell = tf.contrib.rnn.MultiRNNCell(cells=[cell] * 2,
state_is_tuple=True)
inp = tf.constant(0.5, shape=[2, 2, 2])
- enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
+ enc_outputs, enc_state = tf.contrib.rnn.static_rnn(
+ cell, inp, dtype=tf.float32)
attn_states = tf.concat_v2(
[tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs],
1)
@@ -364,7 +368,8 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2])] * 2
cell = tf.contrib.rnn.GRUCell(2)
- enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
+ enc_outputs, enc_state = tf.contrib.rnn.static_rnn(
+ cell, inp, dtype=tf.float32)
attn_states = tf.concat_v2(
[tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs], 1)
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in range(3)]