diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 56 |
1 files changed, 28 insertions, 28 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 2df8f0ec05..6689664fb9 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -47,7 +47,7 @@ from tensorflow.python.util import nest class RNNCellTest(test.TestCase): def testCoupledInputForgetGateLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 state_size = num_units * 2 batch_size = 3 @@ -81,7 +81,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1], expected_state) def testTimeFreqLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 state_size = num_units * 2 batch_size = 3 @@ -120,7 +120,7 @@ class RNNCellTest(test.TestCase): float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) def testGridLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 batch_size = 3 input_size = 4 @@ -166,7 +166,7 @@ class RNNCellTest(test.TestCase): .state_f00_b00_c[i, :]))) > 1e-6) def testGridLSTMCellWithFrequencyBlocks(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 batch_size = 3 feature_size = 2 @@ -248,7 +248,7 @@ class RNNCellTest(test.TestCase): ]], dtype=np.float32) for state_is_tuple in [False, True]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple" + str(state_is_tuple), initializer=init_ops.constant_initializer(0.5)): @@ -294,7 +294,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) def testBidirectionGridLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -374,7 +374,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) def testBidirectionGridLSTMCellWithSliceOffset(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -487,7 +487,7 @@ class RNNCellTest(test.TestCase): input_size = 4 for state_is_tuple in [False, True]: with ops.Graph().as_default(): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple_" + str(state_is_tuple)): lstm_cell = rnn_cell.BasicLSTMCell( @@ -538,7 +538,7 @@ class RNNCellTest(test.TestCase): batch_size = 3 for state_is_tuple in [False, True]: with ops.Graph().as_default(): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "state_is_tuple_" + str(state_is_tuple)): lstm_cell = rnn_cell.BasicLSTMCell( @@ -677,7 +677,7 @@ class RNNCellTest(test.TestCase): 0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348 ]]) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "nas_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.NASCell(num_units=num_units) @@ -725,7 +725,7 @@ class RNNCellTest(test.TestCase): 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517 ]]) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "nas_proj_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj) @@ -765,7 +765,7 @@ class RNNCellTest(test.TestCase): [[0.13752282, 0.13752282], [0.10545051, 0.10545051], [0.10074195, 0.10074195]], dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)): cell = contrib_rnn_cell.UGRNNCell(num_units=num_units) @@ -796,7 +796,7 @@ class RNNCellTest(test.TestCase): [[2.00431061, 2.00431061], [4.00060606, 4.00060606], [6.00008249, 6.00008249]], dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "intersection_rnn_cell_test", initializer=init_ops.constant_initializer(0.5)): @@ -837,7 +837,7 @@ class RNNCellTest(test.TestCase): cell(inputs, init_state) def testPhasedLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 batch_size = 3 input_size = 4 @@ -874,7 +874,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv1DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 1] filter_size = [3] num_features = 1 @@ -907,7 +907,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv2DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 2, 1] filter_size = [3, 3] num_features = 1 @@ -948,7 +948,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testConv3DLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape = [2, 2, 2, 1] filter_size = [3, 3, 3] num_features = 1 @@ -999,7 +999,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_state_h) def testHighwayWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "base_cell", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -1030,7 +1030,7 @@ class RNNCellTest(test.TestCase): # Try with input dimension equal to num_units or not. for num_inputs in [num_units, num_units + number_of_groups]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root1_%d" % num_inputs, initializer=init_ops.constant_initializer(0.5)): @@ -1059,7 +1059,7 @@ class RNNCellTest(test.TestCase): # Try with num_inputs equal to or not equal to num_units. for num_inputs in [num_units, num_units + number_of_groups]: - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root2_%d" % num_inputs, initializer=init_ops.constant_initializer(0.5)): @@ -1092,7 +1092,7 @@ class RNNCellTest(test.TestCase): batch_size = 2 num_units = 4 number_of_groups = 2 - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope( "glstm_failure", initializer=init_ops.constant_initializer(0.5)): gcell = contrib_rnn_cell.GLSTMCell( @@ -1121,7 +1121,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): # NOTE: all the values in the current test case have been calculated. def testBasicLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1189,7 +1189,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): def testBasicLSTMCellWithoutNorm(self): """Tests that BasicLSTMCell with layer_norm=False.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1256,7 +1256,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): self.assertAllClose(res[1].h, expected_h, 1e-5) def testBasicLSTMCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1294,7 +1294,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): def testBasicLSTMCellWithStateTupleLayerNorm(self): """The results of LSTMCell and LayerNormBasicLSTMCell should be the same.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1353,7 +1353,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): num_units = 5 allowed_low = [1, 2, 3] - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "other", initializer=init_ops.constant_initializer(1)): x = array_ops.zeros([1, 5]) @@ -1479,7 +1479,7 @@ class CompiledWrapperTest(test.TestCase): self.assertAllClose(xla_g, non_xla_g, atol=atol) def testMultiRNNCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -1583,7 +1583,7 @@ class WeightNormLSTMCellTest(test.TestCase): def _cell_output(self, cell): """Calculates cell output.""" - with self.test_session() as sess: + with self.cached_session() as sess: init = init_ops.constant_initializer(0.5) with variable_scope.variable_scope("root", initializer=init): |