aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-09-04 18:32:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 18:39:34 -0700
commit30db26a5f4983b248bd4565d08c59155ad8bb36c (patch)
treeb531d8d39712ab41cdd6167a67f2aff9189f1c2a /tensorflow/contrib/rnn
parente1ba7ee122d218dd39cd423b821078d36b5663d1 (diff)
Test cleanups
- Remove unnecessary use of test_session() in tests that run with eager execution enabled. - Use cached_session() instead of test_session() (self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement.) PiperOrigin-RevId: 211562969
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py73
1 files changed, 36 insertions, 37 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 15ce9d1ce7..be0306cb07 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -48,7 +48,7 @@ Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
class RNNCellTest(test.TestCase):
def testLinear(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(1.0)):
x = array_ops.zeros([1, 2])
@@ -69,7 +69,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(len(variables_lib.trainable_variables()), 2)
def testBasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -89,7 +89,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellNotTrainable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def not_trainable_getter(getter, *args, **kwargs):
kwargs["trainable"] = False
@@ -116,7 +116,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testIndRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -137,7 +137,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -165,7 +165,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.156736, 0.156736]])
def testIndyGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -193,7 +193,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.155127, 0.157328]])
def testSRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -208,7 +208,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.509682, 0.509682]])
def testSRUCellWithDiffSize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -288,7 +288,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCellDimension0Error(self):
"""Tests that dimension 0 in both(x and m) shape must be equal."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
@@ -309,7 +309,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCellStateSizeError(self):
"""Tests that state_size must be num_units * 2."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
@@ -329,7 +329,7 @@ class RNNCellTest(test.TestCase):
})
def testBasicLSTMCellStateTupleType(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -360,7 +360,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -459,7 +459,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(len(res), 2)
def testLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
num_proj = 6
state_size = num_units + num_proj
@@ -494,7 +494,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testLSTMCellVariables(self):
- with self.test_session():
+ with self.cached_session():
num_units = 8
num_proj = 6
state_size = num_units + num_proj
@@ -517,7 +517,7 @@ class RNNCellTest(test.TestCase):
"root/lstm_cell/projection/kernel")
def testLSTMCellLayerNorm(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
num_proj = 3
batch_size = 1
@@ -562,22 +562,21 @@ class RNNCellTest(test.TestCase):
rnn_cell_impl.DropoutWrapper,
rnn_cell_impl.ResidualWrapper,
lambda cell: rnn_cell_impl.MultiRNNCell([cell])]:
- with self.test_session():
- cell = rnn_cell_impl.BasicRNNCell(1)
- wrapper = wrapper_type(cell)
- wrapper(array_ops.ones([1, 1]),
- state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
- self.evaluate([v.initializer for v in cell.variables])
- checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
- prefix = os.path.join(self.get_temp_dir(), "ckpt")
- self.evaluate(cell._bias.assign([40.]))
- save_path = checkpoint.save(prefix)
- self.evaluate(cell._bias.assign([0.]))
- checkpoint.restore(save_path).assert_consumed().run_restore_ops()
- self.assertAllEqual([40.], self.evaluate(cell._bias))
+ cell = rnn_cell_impl.BasicRNNCell(1)
+ wrapper = wrapper_type(cell)
+ wrapper(array_ops.ones([1, 1]),
+ state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
+ self.evaluate([v.initializer for v in cell.variables])
+ checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(cell._bias.assign([40.]))
+ save_path = checkpoint.save(prefix)
+ self.evaluate(cell._bias.assign([0.]))
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.assertAllEqual([40.], self.evaluate(cell._bias))
def testOutputProjectionWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -594,7 +593,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.231907, 0.231907]])
def testInputProjectionWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -612,7 +611,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
def testResidualWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -638,7 +637,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[2], res[3])
def testResidualWrapperWithSlice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 5])
@@ -716,7 +715,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
def testEmbeddingWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1], dtype=dtypes.int32)
@@ -735,7 +734,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.17139, 0.17139]])
def testEmbeddingWrapperWithDynamicRnn(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root"):
inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
@@ -753,7 +752,7 @@ class RNNCellTest(test.TestCase):
sess.run(outputs)
def testMultiRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -770,7 +769,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -809,7 +808,7 @@ class DropoutWrapperTest(test.TestCase):
time_steps=None,
parallel_iterations=None,
**kwargs):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
if batch_size is None and time_steps is None: