aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/grid_rnn
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2016-12-10 11:37:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-10 11:46:06 -0800
commit13b0c97780b690e34f8b40057cd789080fb489fd (patch)
treef3a4aa87231c4f47a7170f796b3a70e3fffa9d2b /tensorflow/contrib/grid_rnn
parentd8f2a4b0e2548f1f2ea8ca44c134a2a2604af5c6 (diff)
Update caller to move from tf.nn.rnn to (the identical) tf.contrib.rnn.static_rnn.
Change: 141658438
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index c7b14eff40..e5ebf89603 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -408,7 +408,7 @@ class GridRNNCellTest(tf.test.TestCase):
tf.float32, shape=(batch_size, input_size))
]
- outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+ outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state.get_shape(), (batch_size, 8))
@@ -441,7 +441,7 @@ class GridRNNCellTest(tf.test.TestCase):
tf.float32, shape=(batch_size, input_size))
]
- outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+ outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state.get_shape(), (batch_size, 4))
@@ -474,7 +474,7 @@ class GridRNNCellTest(tf.test.TestCase):
tf.float32, shape=(batch_size, input_size))
]
- outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+ outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state.get_shape(), (batch_size, 8))
@@ -505,7 +505,7 @@ class GridRNNCellTest(tf.test.TestCase):
inputs = ([tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ (max_length - 1) * [tf.zeros([batch_size, input_size])])
- outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+ outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state.get_shape(), (batch_size, 4))