diff options
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils_test.py')
-rw-r--r-- | tensorflow/python/training/checkpoint_utils_test.py | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py index 1aab16338a..61dcbdb2b8 100644 --- a/tensorflow/python/training/checkpoint_utils_test.py +++ b/tensorflow/python/training/checkpoint_utils_test.py @@ -84,7 +84,7 @@ class CheckpointsTest(test.TestCase): def testNoTensor(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, _ = _create_checkpoints(session, checkpoint_dir) with self.assertRaises(errors_impl.OpError): self.assertAllEqual( @@ -92,7 +92,7 @@ class CheckpointsTest(test.TestCase): def testGetTensor(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) self.assertAllEqual( checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1) @@ -105,7 +105,7 @@ class CheckpointsTest(test.TestCase): def testGetAllVariables(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _create_checkpoints(session, checkpoint_dir) self.assertEqual( checkpoint_utils.list_variables(checkpoint_dir), @@ -114,7 +114,7 @@ class CheckpointsTest(test.TestCase): def testInitFromCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -148,7 +148,7 @@ class CheckpointsTest(test.TestCase): def testInitialValueComesFromCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, _, _, _ = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -178,7 +178,7 @@ class CheckpointsTest(test.TestCase): def testInitWithScopeDoesNotCaptureSuffixes(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, v4 = _create_checkpoints(session, checkpoint_dir) with ops.Graph().as_default() as g: @@ -197,7 +197,7 @@ class CheckpointsTest(test.TestCase): def testRestoreRunsOnSameDevice(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _create_checkpoints(session, checkpoint_dir) with ops.Graph().as_default(): @@ -213,7 +213,7 @@ class CheckpointsTest(test.TestCase): def testInitFromRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -237,7 +237,7 @@ class CheckpointsTest(test.TestCase): def testInitToRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -260,7 +260,7 @@ class CheckpointsTest(test.TestCase): def testInitFromPartitionVar(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1 = _create_partition_checkpoints(session, checkpoint_dir) # New graph and session. @@ -322,7 +322,7 @@ class CheckpointsTest(test.TestCase): def testInitFromCheckpointMissing(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: _, _, _, _ = _create_checkpoints(session, checkpoint_dir) # New graph and session. @@ -367,7 +367,7 @@ class CheckpointsTest(test.TestCase): def testNoAdditionalReadOpsForResourceVariables(self): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1, _, _, _ = _create_checkpoints(session, checkpoint_dir) # New graph and session. |