aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpoint_utils_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils_test.py')
-rw-r--r--tensorflow/python/training/checkpoint_utils_test.py24
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py
index 1aab16338a..61dcbdb2b8 100644
--- a/tensorflow/python/training/checkpoint_utils_test.py
+++ b/tensorflow/python/training/checkpoint_utils_test.py
@@ -84,7 +84,7 @@ class CheckpointsTest(test.TestCase):
def testNoTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
with self.assertRaises(errors_impl.OpError):
self.assertAllEqual(
@@ -92,7 +92,7 @@ class CheckpointsTest(test.TestCase):
def testGetTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
self.assertAllEqual(
checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
@@ -105,7 +105,7 @@ class CheckpointsTest(test.TestCase):
def testGetAllVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
self.assertEqual(
checkpoint_utils.list_variables(checkpoint_dir),
@@ -114,7 +114,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -148,7 +148,7 @@ class CheckpointsTest(test.TestCase):
def testInitialValueComesFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -178,7 +178,7 @@ class CheckpointsTest(test.TestCase):
def testInitWithScopeDoesNotCaptureSuffixes(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default() as g:
@@ -197,7 +197,7 @@ class CheckpointsTest(test.TestCase):
def testRestoreRunsOnSameDevice(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default():
@@ -213,7 +213,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -237,7 +237,7 @@ class CheckpointsTest(test.TestCase):
def testInitToRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -260,7 +260,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromPartitionVar(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1 = _create_partition_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -322,7 +322,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpointMissing(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -367,7 +367,7 @@ class CheckpointsTest(test.TestCase):
def testNoAdditionalReadOpsForResourceVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.