aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/grid_rnn
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2016-12-29 22:46:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-29 23:06:59 -0800
commite121667dc609de978a223c56ee906368d2c4ceef (patch)
tree7d4e1f1e1b4fd469487872c0cd34ddace5ac570c /tensorflow/contrib/grid_rnn
parent7815fcba7767aa1eb3196c5861e174f8b3c43bab (diff)
Remove so many more hourglass imports
Change: 143230429
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--tensorflow/contrib/grid_rnn/BUILD10
-rw-r--r--tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py332
2 files changed, 201 insertions, 141 deletions
diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD
index 04cdc1e135..73473becf9 100644
--- a/tensorflow/contrib/grid_rnn/BUILD
+++ b/tensorflow/contrib/grid_rnn/BUILD
@@ -29,9 +29,17 @@ cuda_py_tests(
srcs = ["python/kernel_tests/grid_rnn_test.py"],
additional_deps = [
":grid_rnn_py",
- "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/rnn:rnn_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
],
)
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index e5ebf89603..e2a5a5556f 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -18,29 +18,46 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
+# TODO: #6568 Remove this hack that makes dlopen() not crash.
+if hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags'):
+ import ctypes
+ sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
+
import numpy as np
-import tensorflow as tf
+from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell
+from tensorflow.contrib.rnn.python.ops import core_rnn
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
-class GridRNNCellTest(tf.test.TestCase):
+
+class GridRNNCellTest(test.TestCase):
def testGrid2BasicLSTMCell(self):
with self.test_session() as sess:
- with tf.variable_scope(
- 'root', initializer=tf.constant_initializer(0.2)) as root_scope:
- x = tf.zeros([1, 3])
- m = tf.zeros([1, 8])
- cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.2)) as root_scope:
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 8])
+ cell = grid_rnn_cell.Grid2BasicLSTMCell(2)
self.assertEqual(cell.state_size, 8)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 8))
- sess.run([tf.global_variables_initializer()])
- res = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ })
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
@@ -65,20 +82,22 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid2BasicLSTMCellTied(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)):
- x = tf.zeros([1, 3])
- m = tf.zeros([1, 8])
- cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=True)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.2)):
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 8])
+ cell = grid_rnn_cell.Grid2BasicLSTMCell(2, tied=True)
self.assertEqual(cell.state_size, 8)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 8))
- sess.run([tf.global_variables_initializer()])
- res = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ })
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
@@ -96,45 +115,50 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid2BasicLSTMCellWithRelu(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)):
- x = tf.zeros([1, 3])
- m = tf.zeros([1, 4])
- cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(
- 2, tied=False, non_recurrent_fn=tf.nn.relu)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.2)):
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 4])
+ cell = grid_rnn_cell.Grid2BasicLSTMCell(
+ 2, tied=False, non_recurrent_fn=nn_ops.relu)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 4))
- sess.run([tf.global_variables_initializer()])
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run(
+ [g, s],
+ {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[0.31667367, 0.31667367]])
- self.assertAllClose(res[1],
- [[0.29530135, 0.37520045, 0.17044567, 0.21292259]])
+ self.assertAllClose(res[1], [[0.29530135, 0.37520045, 0.17044567,
+ 0.21292259]])
"""LSTMCell
"""
def testGrid2LSTMCell(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([1, 3])
- m = tf.zeros([1, 8])
- cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 8])
+ cell = grid_rnn_cell.Grid2LSTMCell(2, use_peepholes=True)
self.assertEqual(cell.state_size, 8)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 8))
- sess.run([tf.global_variables_initializer()])
- res = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ })
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
@@ -144,21 +168,22 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid2LSTMCellTied(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([1, 3])
- m = tf.zeros([1, 8])
- cell = tf.contrib.grid_rnn.Grid2LSTMCell(
- 2, tied=True, use_peepholes=True)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 8])
+ cell = grid_rnn_cell.Grid2LSTMCell(2, tied=True, use_peepholes=True)
self.assertEqual(cell.state_size, 8)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 8))
- sess.run([tf.global_variables_initializer()])
- res = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ })
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
@@ -168,45 +193,50 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid2LSTMCellWithRelu(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([1, 3])
- m = tf.zeros([1, 4])
- cell = tf.contrib.grid_rnn.Grid2LSTMCell(
- 2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 4])
+ cell = grid_rnn_cell.Grid2LSTMCell(
+ 2, use_peepholes=True, non_recurrent_fn=nn_ops.relu)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 4))
- sess.run([tf.global_variables_initializer()])
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run(
+ [g, s],
+ {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[2.1831727, 2.1831727]])
- self.assertAllClose(res[1],
- [[0.92270052, 1.02325559, 0.66159075, 0.70475441]])
+ self.assertAllClose(res[1], [[0.92270052, 1.02325559, 0.66159075,
+ 0.70475441]])
"""RNNCell
"""
def testGrid2BasicRNNCell(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([2, 2])
- m = tf.zeros([2, 4])
- cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([2, 2])
+ m = array_ops.zeros([2, 4])
+ cell = grid_rnn_cell.Grid2BasicRNNCell(2)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (2, 2))
self.assertEqual(s.get_shape(), (2, 4))
- sess.run([tf.global_variables_initializer()])
- res = sess.run(
- [g, s], {x: np.array([[1., 1.], [2., 2.]]),
- m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s], {
+ x: np.array([[1., 1.], [2., 2.]]),
+ m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])
+ })
self.assertEqual(res[0].shape, (2, 2))
self.assertEqual(res[1].shape, (2, 4))
self.assertAllClose(res[0], [[0.94685763, 0.94685763],
@@ -217,20 +247,22 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid2BasicRNNCellTied(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([2, 2])
- m = tf.zeros([2, 4])
- cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, tied=True)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([2, 2])
+ m = array_ops.zeros([2, 4])
+ cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (2, 2))
self.assertEqual(s.get_shape(), (2, 4))
- sess.run([tf.global_variables_initializer()])
- res = sess.run(
- [g, s], {x: np.array([[1., 1.], [2., 2.]]),
- m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s], {
+ x: np.array([[1., 1.], [2., 2.]]),
+ m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])
+ })
self.assertEqual(res[0].shape, (2, 2))
self.assertEqual(res[1].shape, (2, 4))
self.assertAllClose(res[0], [[0.94685763, 0.94685763],
@@ -241,20 +273,21 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid2BasicRNNCellWithRelu(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([1, 2])
- m = tf.zeros([1, 2])
- cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(
- 2, non_recurrent_fn=tf.nn.relu)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = grid_rnn_cell.Grid2BasicRNNCell(2, non_recurrent_fn=nn_ops.relu)
self.assertEqual(cell.state_size, 2)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 2))
- sess.run([tf.global_variables_initializer()])
- res = sess.run([g, s], {x: np.array([[1., 1.]]),
- m: np.array([[0.1, 0.1]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s],
+ {x: np.array([[1., 1.]]),
+ m: np.array([[0.1, 0.1]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 2))
self.assertAllClose(res[0], [[1.80049896, 1.80049896]])
@@ -265,20 +298,22 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid1LSTMCell(self):
with self.test_session() as sess:
- with tf.variable_scope(
- 'root', initializer=tf.constant_initializer(0.5)) as root_scope:
- x = tf.zeros([1, 3])
- m = tf.zeros([1, 4])
- cell = tf.contrib.grid_rnn.Grid1LSTMCell(2, use_peepholes=True)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 4])
+ cell = grid_rnn_cell.Grid1LSTMCell(2, use_peepholes=True)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 4))
- sess.run([tf.global_variables_initializer()])
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run(
+ [g, s],
+ {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[0.91287315, 0.91287315]])
@@ -287,12 +322,12 @@ class GridRNNCellTest(tf.test.TestCase):
root_scope.reuse_variables()
- x2 = tf.zeros([0, 0])
+ x2 = array_ops.zeros([0, 0])
g2, s2 = cell(x2, m)
self.assertEqual(g2.get_shape(), (1, 2))
self.assertEqual(s2.get_shape(), (1, 4))
- sess.run([tf.global_variables_initializer()])
+ sess.run([variables.global_variables_initializer()])
res = sess.run([g2, s2], {m: res[1]})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
@@ -304,7 +339,7 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(g3.get_shape(), (1, 2))
self.assertEqual(s3.get_shape(), (1, 4))
- sess.run([tf.global_variables_initializer()])
+ sess.run([variables.global_variables_initializer()])
res = sess.run([g3, s3], {m: res[1]})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
@@ -317,20 +352,27 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid3LSTMCell(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([1, 3])
- m = tf.zeros([1, 12])
- cell = tf.contrib.grid_rnn.Grid3LSTMCell(2, use_peepholes=True)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 12])
+ cell = grid_rnn_cell.Grid3LSTMCell(2, use_peepholes=True)
self.assertEqual(cell.state_size, 12)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (1, 2))
self.assertEqual(s.get_shape(), (1, 12))
- sess.run([tf.global_variables_initializer()])
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
- 0.8, -0.1, -0.2, -0.3, -0.4]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m:
+ np.array([[
+ 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3,
+ -0.4
+ ]])
+ })
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 12))
@@ -345,23 +387,24 @@ class GridRNNCellTest(tf.test.TestCase):
def testGridRNNEdgeCasesLikeRelu(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([3, 2])
- m = tf.zeros([0, 0])
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([3, 2])
+ m = array_ops.zeros([0, 0])
# this is equivalent to relu
- cell = tf.contrib.grid_rnn.GridRNNCell(
+ cell = grid_rnn_cell.GridRNNCell(
num_units=2,
num_dims=1,
input_dims=0,
output_dims=0,
non_recurrent_dims=0,
- non_recurrent_fn=tf.nn.relu)
+ non_recurrent_fn=nn_ops.relu)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (3, 2))
self.assertEqual(s.get_shape(), (0, 0))
- sess.run([tf.global_variables_initializer()])
+ sess.run([variables.global_variables_initializer()])
res = sess.run([g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
self.assertEqual(res[0].shape, (3, 2))
self.assertEqual(res[1].shape, (0, 0))
@@ -369,25 +412,28 @@ class GridRNNCellTest(tf.test.TestCase):
def testGridRNNEdgeCasesNoOutput(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([1, 2])
- m = tf.zeros([1, 4])
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 4])
# This cell produces no output
- cell = tf.contrib.grid_rnn.GridRNNCell(
+ cell = grid_rnn_cell.GridRNNCell(
num_units=2,
num_dims=2,
input_dims=0,
output_dims=None,
non_recurrent_dims=0,
- non_recurrent_fn=tf.nn.relu)
+ non_recurrent_fn=nn_ops.relu)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (0, 0))
self.assertEqual(s.get_shape(), (1, 4))
- sess.run([tf.global_variables_initializer()])
- res = sess.run([g, s], {x: np.array([[1., 1.]]),
- m: np.array([[0.1, 0.1, 0.1, 0.1]])})
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run(
+ [g, s],
+ {x: np.array([[1., 1.]]),
+ m: np.array([[0.1, 0.1, 0.1, 0.1]])})
self.assertEqual(res[0].shape, (0, 0))
self.assertEqual(res[1].shape, (1, 4))
@@ -400,15 +446,16 @@ class GridRNNCellTest(tf.test.TestCase):
max_length = 6 # unrolled up to this length
num_units = 2
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units)
inputs = max_length * [
- tf.placeholder(
- tf.float32, shape=(batch_size, input_size))
+ array_ops.placeholder(
+ dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
+ outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state.get_shape(), (batch_size, 8))
@@ -419,7 +466,7 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(out.dtype, inp.dtype)
with self.test_session() as sess:
- sess.run(tf.global_variables_initializer())
+ sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
@@ -432,16 +479,17 @@ class GridRNNCellTest(tf.test.TestCase):
max_length = 6 # unrolled up to this length
num_units = 2
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- cell = tf.contrib.grid_rnn.Grid2LSTMCell(
- num_units=num_units, non_recurrent_fn=tf.nn.relu)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ cell = grid_rnn_cell.Grid2LSTMCell(
+ num_units=num_units, non_recurrent_fn=nn_ops.relu)
inputs = max_length * [
- tf.placeholder(
- tf.float32, shape=(batch_size, input_size))
+ array_ops.placeholder(
+ dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
+ outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state.get_shape(), (batch_size, 4))
@@ -452,7 +500,7 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(out.dtype, inp.dtype)
with self.test_session() as sess:
- sess.run(tf.global_variables_initializer())
+ sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
@@ -465,16 +513,17 @@ class GridRNNCellTest(tf.test.TestCase):
max_length = 6 # unrolled up to this length
num_units = 2
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- cell = tf.contrib.grid_rnn.Grid3LSTMCell(
- num_units=num_units, non_recurrent_fn=tf.nn.relu)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ cell = grid_rnn_cell.Grid3LSTMCell(
+ num_units=num_units, non_recurrent_fn=nn_ops.relu)
inputs = max_length * [
- tf.placeholder(
- tf.float32, shape=(batch_size, input_size))
+ array_ops.placeholder(
+ dtypes.float32, shape=(batch_size, input_size))
]
- outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
+ outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state.get_shape(), (batch_size, 8))
@@ -485,7 +534,7 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(out.dtype, inp.dtype)
with self.test_session() as sess:
- sess.run(tf.global_variables_initializer())
+ sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
@@ -498,14 +547,17 @@ class GridRNNCellTest(tf.test.TestCase):
max_length = 6 # unrolled up to this length
num_units = 2
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- cell = tf.contrib.grid_rnn.Grid1LSTMCell(num_units=num_units)
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ cell = grid_rnn_cell.Grid1LSTMCell(num_units=num_units)
# for 1-LSTM, we only feed the first step
- inputs = ([tf.placeholder(tf.float32, shape=(batch_size, input_size))]
- + (max_length - 1) * [tf.zeros([batch_size, input_size])])
+ inputs = ([
+ array_ops.placeholder(
+ dtypes.float32, shape=(batch_size, input_size))
+ ] + (max_length - 1) * [array_ops.zeros([batch_size, input_size])])
- outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
+ outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(state.get_shape(), (batch_size, 4))
@@ -515,7 +567,7 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(out.dtype, inp.dtype)
with self.test_session() as sess:
- sess.run(tf.global_variables_initializer())
+ sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
@@ -524,4 +576,4 @@ class GridRNNCellTest(tf.test.TestCase):
if __name__ == '__main__':
- tf.test.main()
+ test.main()