author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-10 14:36:52 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-10 15:01:47 -0700
commit     acf0ee82092727afc2067316982407cf5e496f75 (patch)
tree       e5df1811ab47e259a1f30c46e22c251411ad326e /tensorflow/contrib/rnn
parent     f1cc58bb4144de61a693076d8ff8a26b2644ebbb (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44, as its name confuses readers of the test. Moving to cached_session() instead, which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.

PiperOrigin-RevId: 212336417
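For readers unfamiliar with the pattern being migrated, here is a minimal sketch of the old and new styles. It assumes a TF 1.x-style graph-mode test; the test class name and the computation are hypothetical and not part of this commit, only test_session()/cached_session() come from tf.test.TestCase.

# Minimal sketch (not from this commit) of the substitution applied below.
import tensorflow as tf


class CachedSessionExampleTest(tf.test.TestCase):

  def testAdd(self):
    # Deprecated style:
    #   with self.test_session() as sess:
    # Preferred style: cached_session() makes explicit that the session is
    # cached and may be reused by later calls in the same test, and that it
    # is not closed when the "with" block exits.
    with self.cached_session() as sess:
      total = tf.add(1.0, 2.0)
      self.assertAllClose(sess.run(total), 3.0)


if __name__ == "__main__":
  tf.test.main()

In the hunks below, each "with self.test_session() as sess:" line is replaced by "with self.cached_session() as sess:" with no other change intended.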
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py         2
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py   4
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py        56
3 files changed, 31 insertions, 31 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index aa4562be7c..bf699db3ed 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -1906,7 +1906,7 @@ class StateSaverRNNTest(test.TestCase):
state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
sess.run(variables_lib.local_variables_initializer())
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
index f2a032e41e..8d34b9e852 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
@@ -38,7 +38,7 @@ class FusedRnnCellTest(test.TestCase):
def testBasicRNNFusedWrapper(self):
"""This test checks that using a wrapper for BasicRNN works as expected."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890212)
cell = rnn_cell.BasicRNNCell(10)
@@ -106,7 +106,7 @@ class FusedRnnCellTest(test.TestCase):
self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
def testTimeReversedFusedRNN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890213)
fw_cell = rnn_cell.BasicRNNCell(10)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 2df8f0ec05..6689664fb9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -47,7 +47,7 @@ from tensorflow.python.util import nest
class RNNCellTest(test.TestCase):
def testCoupledInputForgetGateLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
state_size = num_units * 2
batch_size = 3
@@ -81,7 +81,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_state)
def testTimeFreqLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
state_size = num_units * 2
batch_size = 3
@@ -120,7 +120,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
input_size = 4
@@ -166,7 +166,7 @@ class RNNCellTest(test.TestCase):
.state_f00_b00_c[i, :]))) > 1e-6)
def testGridLSTMCellWithFrequencyBlocks(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
feature_size = 2
@@ -248,7 +248,7 @@ class RNNCellTest(test.TestCase):
]],
dtype=np.float32)
for state_is_tuple in [False, True]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple" + str(state_is_tuple),
initializer=init_ops.constant_initializer(0.5)):
@@ -294,7 +294,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -374,7 +374,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCellWithSliceOffset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -487,7 +487,7 @@ class RNNCellTest(test.TestCase):
input_size = 4
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -538,7 +538,7 @@ class RNNCellTest(test.TestCase):
batch_size = 3
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -677,7 +677,7 @@ class RNNCellTest(test.TestCase):
0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units)
@@ -725,7 +725,7 @@ class RNNCellTest(test.TestCase):
0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
@@ -765,7 +765,7 @@ class RNNCellTest(test.TestCase):
[[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
@@ -796,7 +796,7 @@ class RNNCellTest(test.TestCase):
[[2.00431061, 2.00431061], [4.00060606, 4.00060606],
[6.00008249, 6.00008249]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"intersection_rnn_cell_test",
initializer=init_ops.constant_initializer(0.5)):
@@ -837,7 +837,7 @@ class RNNCellTest(test.TestCase):
cell(inputs, init_state)
def testPhasedLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -874,7 +874,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv1DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 1]
filter_size = [3]
num_features = 1
@@ -907,7 +907,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv2DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 1]
filter_size = [3, 3]
num_features = 1
@@ -948,7 +948,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv3DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 2, 1]
filter_size = [3, 3, 3]
num_features = 1
@@ -999,7 +999,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testHighwayWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"base_cell", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -1030,7 +1030,7 @@ class RNNCellTest(test.TestCase):
# Try with input dimension equal to num_units or not.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root1_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1059,7 +1059,7 @@ class RNNCellTest(test.TestCase):
# Try with num_inputs equal to or not equal to num_units.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root2_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1092,7 +1092,7 @@ class RNNCellTest(test.TestCase):
batch_size = 2
num_units = 4
number_of_groups = 2
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"glstm_failure", initializer=init_ops.constant_initializer(0.5)):
gcell = contrib_rnn_cell.GLSTMCell(
@@ -1121,7 +1121,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
# NOTE: all the values in the current test case have been calculated.
def testBasicLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1189,7 +1189,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithoutNorm(self):
"""Tests that BasicLSTMCell with layer_norm=False."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1256,7 +1256,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_h, 1e-5)
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1294,7 +1294,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithStateTupleLayerNorm(self):
"""The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1353,7 +1353,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
num_units = 5
allowed_low = [1, 2, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"other", initializer=init_ops.constant_initializer(1)):
x = array_ops.zeros([1, 5])
@@ -1479,7 +1479,7 @@ class CompiledWrapperTest(test.TestCase):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1583,7 +1583,7 @@ class WeightNormLSTMCellTest(test.TestCase):
def _cell_output(self, cell):
"""Calculates cell output."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init = init_ops.constant_initializer(0.5)
with variable_scope.variable_scope("root",
initializer=init):