diff options
author | Jianwei Xie <xiejw@google.com> | 2016-12-10 11:37:06 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-10 11:46:06 -0800 |
commit | 13b0c97780b690e34f8b40057cd789080fb489fd (patch) | |
tree | f3a4aa87231c4f47a7170f796b3a70e3fffa9d2b /tensorflow/contrib/grid_rnn | |
parent | d8f2a4b0e2548f1f2ea8ca44c134a2a2604af5c6 (diff) |
Update caller to move from tf.nn.rnn to (the identical) tf.contrib.rnn.static_rnn.
Change: 141658438
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py index c7b14eff40..e5ebf89603 100644 --- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py +++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py @@ -408,7 +408,7 @@ class GridRNNCellTest(tf.test.TestCase): tf.float32, shape=(batch_size, input_size)) ] - outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) + outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) self.assertEqual(state.get_shape(), (batch_size, 8)) @@ -441,7 +441,7 @@ class GridRNNCellTest(tf.test.TestCase): tf.float32, shape=(batch_size, input_size)) ] - outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) + outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) self.assertEqual(state.get_shape(), (batch_size, 4)) @@ -474,7 +474,7 @@ class GridRNNCellTest(tf.test.TestCase): tf.float32, shape=(batch_size, input_size)) ] - outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) + outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) self.assertEqual(state.get_shape(), (batch_size, 8)) @@ -505,7 +505,7 @@ class GridRNNCellTest(tf.test.TestCase): inputs = ([tf.placeholder(tf.float32, shape=(batch_size, input_size))] + (max_length - 1) * [tf.zeros([batch_size, input_size])]) - outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) + outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) self.assertEqual(state.get_shape(), (batch_size, 4)) |