diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-10 14:36:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 15:01:47 -0700 |
commit | acf0ee82092727afc2067316982407cf5e496f75 (patch) | |
tree | e5df1811ab47e259a1f30c46e22c251411ad326e /tensorflow/contrib/rnn | |
parent | f1cc58bb4144de61a693076d8ff8a26b2644ebbb (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 212336417
Diffstat (limited to 'tensorflow/contrib/rnn')
3 files changed, 31 insertions, 31 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index aa4562be7c..bf699db3ed 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -1906,7 +1906,7 @@ class StateSaverRNNTest(test.TestCase): state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units) out, state, state_saver = self._factory(scope=None, state_saver=state_saver) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables_lib.global_variables_initializer()) sess.run(variables_lib.local_variables_initializer()) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py index f2a032e41e..8d34b9e852 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py @@ -38,7 +38,7 @@ class FusedRnnCellTest(test.TestCase): def testBasicRNNFusedWrapper(self): """This test checks that using a wrapper for BasicRNN works as expected.""" - with self.test_session() as sess: + with self.cached_session() as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=19890212) cell = rnn_cell.BasicRNNCell(10) @@ -106,7 +106,7 @@ class FusedRnnCellTest(test.TestCase): self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) def testTimeReversedFusedRNN(self): - with self.test_session() as sess: + with self.cached_session() as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=19890213) fw_cell = rnn_cell.BasicRNNCell(10) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 2df8f0ec05..6689664fb9 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -47,7 +47,7 @@ from tensorflow.python.util import nest class RNNCellTest(test.TestCase): def testCoupledInputForgetGateLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 state_size = num_units * 2 batch_size = 3 @@ -81,7 +81,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1], expected_state) def testTimeFreqLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 state_size = num_units * 2 batch_size = 3 @@ -120,7 +120,7 @@ class RNNCellTest(test.TestCase): float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) def testGridLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 batch_size = 3 input_size = 4 @@ -166,7 +166,7 @@ class RNNCellTest(test.TestCase): .state_f00_b00_c[i, :]))) > 1e-6) def testGridLSTMCellWithFrequencyBlocks(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 batch_size = 3 feature_size = 2 @@ -248,7 +248,7 @@ class RNNCellTest(test.TestCase): ]], dtype=np.float32) for state_is_tuple in [False, True]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple" + str(state_is_tuple), initializer=init_ops.constant_initializer(0.5)): @@ -294,7 +294,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) def testBidirectionGridLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -374,7 +374,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) def testBidirectionGridLSTMCellWithSliceOffset(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -487,7 +487,7 @@ class RNNCellTest(test.TestCase): input_size = 4 for state_is_tuple in [False, True]: with ops.Graph().as_default(): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple_" + str(state_is_tuple)): lstm_cell = rnn_cell.BasicLSTMCell( @@ -538,7 +538,7 @@ class RNNCellTest(test.TestCase): batch_size = 3 for state_is_tuple in [False, True]: with ops.Graph().as_default(): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple_" + str(state_is_tuple)): lstm_cell = rnn_cell.BasicLSTMCell( @@ -677,7 +677,7 @@ class RNNCellTest(test.TestCase): 0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348 ]]) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "nas_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.NASCell(num_units=num_units) @@ -725,7 +725,7 @@ class RNNCellTest(test.TestCase): 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517 ]]) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "nas_proj_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj) @@ -765,7 +765,7 @@ class RNNCellTest(test.TestCase): [[0.13752282, 0.13752282], [0.10545051, 0.10545051], [0.10074195, 0.10074195]], dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.UGRNNCell(num_units=num_units) @@ -796,7 +796,7 @@ class RNNCellTest(test.TestCase): [[2.00431061, 2.00431061], [4.00060606, 4.00060606], [6.00008249, 6.00008249]], dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "intersection_rnn_cell_test", initializer=init_ops.constant_initializer(0.5)): @@ -837,7 +837,7 @@ class RNNCellTest(test.TestCase): cell(inputs, init_state) def testPhasedLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -874,7 +874,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv1DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 1] filter_size = [3] num_features = 1 @@ -907,7 +907,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv2DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 2, 1] filter_size = [3, 3] num_features = 1 @@ -948,7 +948,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv3DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 2, 2, 1] filter_size = [3, 3, 3] num_features = 1 @@ -999,7 +999,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testHighwayWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "base_cell", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -1030,7 +1030,7 @@ class RNNCellTest(test.TestCase): # Try with input dimension equal to num_units or not. for num_inputs in [num_units, num_units + number_of_groups]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root1_%d" % num_inputs, initializer=init_ops.constant_initializer(0.5)): @@ -1059,7 +1059,7 @@ class RNNCellTest(test.TestCase): # Try with num_inputs equal to or not equal to num_units. for num_inputs in [num_units, num_units + number_of_groups]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root2_%d" % num_inputs, initializer=init_ops.constant_initializer(0.5)): @@ -1092,7 +1092,7 @@ class RNNCellTest(test.TestCase): batch_size = 2 num_units = 4 number_of_groups = 2 - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope( "glstm_failure", initializer=init_ops.constant_initializer(0.5)): gcell = contrib_rnn_cell.GLSTMCell( @@ -1121,7 +1121,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): # NOTE: all the values in the current test case have been calculated. def testBasicLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1189,7 +1189,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): def testBasicLSTMCellWithoutNorm(self): """Tests that BasicLSTMCell with layer_norm=False.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1256,7 +1256,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_h, 1e-5) def testBasicLSTMCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1294,7 +1294,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): def testBasicLSTMCellWithStateTupleLayerNorm(self): """The results of LSTMCell and LayerNormBasicLSTMCell should be the same.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1353,7 +1353,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): num_units = 5 allowed_low = [1, 2, 3] - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "other", initializer=init_ops.constant_initializer(1)): x = array_ops.zeros([1, 5]) @@ -1479,7 +1479,7 @@ class CompiledWrapperTest(test.TestCase): self.assertAllClose(xla_g, non_xla_g, atol=atol) def testMultiRNNCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1583,7 +1583,7 @@ class WeightNormLSTMCellTest(test.TestCase): def _cell_output(self, cell): """Calculates cell output.""" - with self.test_session() as sess: + with self.cached_session() as sess: init = init_ops.constant_initializer(0.5) with variable_scope.variable_scope("root", initializer=init): |