aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/grid_rnn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 00:02:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 00:07:39 -0700
commit2952f5134905af795ba90ae1eb97e39091ba9843 (patch)
treef73bc5cd0342d9449114bd933863c2aa55610aa2 /tensorflow/contrib/grid_rnn
parentcf047f7755f3400ee128db2571042091fe9f8314 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 213944355
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index fed8a771cc..27aed091c2 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -233,7 +233,7 @@ class GridRNNCellTest(test.TestCase):
([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -261,7 +261,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid2BasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -292,7 +292,7 @@ class GridRNNCellTest(test.TestCase):
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellTied(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -323,7 +323,7 @@ class GridRNNCellTest(test.TestCase):
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -348,7 +348,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid1LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
x = array_ops.zeros([1, 3])
@@ -410,7 +410,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid3LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -455,7 +455,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGridRNNEdgeCasesLikeRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([3, 2])
@@ -481,7 +481,7 @@ class GridRNNCellTest(test.TestCase):
self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
def testGridRNNEdgeCasesNoOutput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -541,7 +541,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -581,7 +581,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -623,7 +623,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -663,7 +663,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape(), (3, num_units))
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -700,7 +700,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((3, input_size))
@@ -715,7 +715,7 @@ class GridRNNCellTest(test.TestCase):
def testGrid2LSTMCellLegacy(self):
"""Test for legacy case (when state_is_tuple=False)."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])