aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/saver_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/saver_test.py')
-rw-r--r--tensorflow/python/training/saver_test.py95
1 files changed, 60 insertions, 35 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 78e2b073df..42c1cab802 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -68,18 +68,6 @@ from tensorflow.python.training import saver as saver_module
from tensorflow.python.util import compat
-# pylint: disable=invalid-name
-def _TestDir(test_name):
- test_dir = os.path.join(test.get_temp_dir(), test_name)
- if os.path.exists(test_dir):
- shutil.rmtree(test_dir)
- gfile.MakeDirs(test_dir)
- return test_dir
-
-
-# pylint: enable=invalid-name
-
-
class CheckpointedOp(object):
"""Op with a custom checkpointing implementation.
@@ -591,6 +579,11 @@ class SaverTest(test.TestCase):
class SaveRestoreShardedTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testBasics(self):
save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
@@ -719,7 +712,9 @@ class SaveRestoreShardedTest(test.TestCase):
var_full_shape = [10, 3]
# Allows save/restore mechanism to work w/ different slicings.
var_name = "my_var"
- saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt")
+ saved_dir = self._get_test_dir("partitioned_variables")
+ saved_path = os.path.join(saved_dir, "ckpt")
+
call_saver_with_dict = False # updated by test loop below
def _save(slices=None, partitioner=None):
@@ -842,8 +837,13 @@ class SaveRestoreShardedTest(test.TestCase):
class MaxToKeepTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testNonSharded(self):
- save_dir = _TestDir("max_to_keep_non_sharded")
+ save_dir = self._get_test_dir("max_to_keep_non_sharded")
with self.test_session() as sess:
v = variables.Variable(10.0, name="v")
@@ -963,7 +963,7 @@ class MaxToKeepTest(test.TestCase):
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
def testSharded(self):
- save_dir = _TestDir("max_to_keep_sharded")
+ save_dir = self._get_test_dir("max_to_keep_sharded")
with session.Session(
target="",
@@ -1018,8 +1018,8 @@ class MaxToKeepTest(test.TestCase):
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
def testNoMaxToKeep(self):
- save_dir = _TestDir("no_max_to_keep")
- save_dir2 = _TestDir("max_to_keep_0")
+ save_dir = self._get_test_dir("no_max_to_keep")
+ save_dir2 = self._get_test_dir("max_to_keep_0")
with self.test_session() as sess:
v = variables.Variable(10.0, name="v")
@@ -1046,7 +1046,7 @@ class MaxToKeepTest(test.TestCase):
self.assertTrue(saver_module.checkpoint_exists(s2))
def testNoMetaGraph(self):
- save_dir = _TestDir("no_meta_graph")
+ save_dir = self._get_test_dir("no_meta_graph")
with self.test_session() as sess:
v = variables.Variable(10.0, name="v")
@@ -1060,8 +1060,13 @@ class MaxToKeepTest(test.TestCase):
class KeepCheckpointEveryNHoursTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testNonSharded(self):
- save_dir = _TestDir("keep_checkpoint_every_n_hours")
+ save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")
with self.test_session() as sess:
v = variables.Variable([10.0], name="v")
@@ -1277,8 +1282,13 @@ class LatestCheckpointWithRelativePaths(test.TestCase):
class CheckpointStateTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testAbsPath(self):
- save_dir = _TestDir("abs_paths")
+ save_dir = self._get_test_dir("abs_paths")
abs_path = os.path.join(save_dir, "model-0")
ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
@@ -1297,7 +1307,7 @@ class CheckpointStateTest(test.TestCase):
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
def testAllModelCheckpointPaths(self):
- save_dir = _TestDir("all_models_test")
+ save_dir = self._get_test_dir("all_models_test")
abs_path = os.path.join(save_dir, "model-0")
for paths in [None, [], ["model-2"]]:
ckpt = saver_module.generate_checkpoint_state_proto(
@@ -1309,7 +1319,7 @@ class CheckpointStateTest(test.TestCase):
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
def testUpdateCheckpointState(self):
- save_dir = _TestDir("update_checkpoint_state")
+ save_dir = self._get_test_dir("update_checkpoint_state")
os.chdir(save_dir)
# Make a temporary train directory.
train_dir = "train"
@@ -1325,7 +1335,7 @@ class CheckpointStateTest(test.TestCase):
self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
def testCheckPointStateFailsWhenIncomplete(self):
- save_dir = _TestDir("checkpoint_state_fails_when_incomplete")
+ save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
os.chdir(save_dir)
ckpt_path = os.path.join(save_dir, "checkpoint")
ckpt_file = open(ckpt_path, "w")
@@ -1335,7 +1345,7 @@ class CheckpointStateTest(test.TestCase):
saver_module.get_checkpoint_state(save_dir)
def testCheckPointCompletesRelativePaths(self):
- save_dir = _TestDir("checkpoint_completes_relative_paths")
+ save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
os.chdir(save_dir)
ckpt_path = os.path.join(save_dir, "checkpoint")
ckpt_file = open(ckpt_path, "w")
@@ -1356,8 +1366,13 @@ class CheckpointStateTest(test.TestCase):
class MetaGraphTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testAddCollectionDef(self):
- test_dir = _TestDir("good_collection")
+ test_dir = self._get_test_dir("good_collection")
filename = os.path.join(test_dir, "metafile")
with self.test_session():
# Creates a graph.
@@ -1504,12 +1519,12 @@ class MetaGraphTest(test.TestCase):
self.assertEqual(11.0, v1.eval())
def testMultiSaverCollection(self):
- test_dir = _TestDir("saver_collection")
+ test_dir = self._get_test_dir("saver_collection")
self._testMultiSaverCollectionSave(test_dir)
self._testMultiSaverCollectionRestore(test_dir)
def testBinaryAndTextFormat(self):
- test_dir = _TestDir("binary_and_text")
+ test_dir = self._get_test_dir("binary_and_text")
filename = os.path.join(test_dir, "metafile")
with self.test_session(graph=ops_lib.Graph()):
# Creates a graph.
@@ -1541,7 +1556,7 @@ class MetaGraphTest(test.TestCase):
saver_module.import_meta_graph(filename)
def testSliceVariable(self):
- test_dir = _TestDir("slice_saver")
+ test_dir = self._get_test_dir("slice_saver")
filename = os.path.join(test_dir, "metafile")
with self.test_session():
v1 = variables.Variable([20.0], name="v1")
@@ -1679,7 +1694,7 @@ class MetaGraphTest(test.TestCase):
sess.run(train_op)
def testGraphExtension(self):
- test_dir = _TestDir("graph_extension")
+ test_dir = self._get_test_dir("graph_extension")
self._testGraphExtensionSave(test_dir)
self._testGraphExtensionRestore(test_dir)
self._testRestoreFromTrainGraphWithControlContext(test_dir)
@@ -1722,7 +1737,7 @@ class MetaGraphTest(test.TestCase):
def testImportIntoNamescope(self):
# Test that we can import a meta graph into a namescope.
- test_dir = _TestDir("import_into_namescope")
+ test_dir = self._get_test_dir("import_into_namescope")
filename = os.path.join(test_dir, "ckpt")
image = array_ops.placeholder(dtypes.float32, [None, 784])
label = array_ops.placeholder(dtypes.float32, [None, 10])
@@ -1870,8 +1885,13 @@ class CheckpointReaderForV2Test(CheckpointReaderTest):
class WriteGraphTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testWriteGraph(self):
- test_dir = _TestDir("write_graph_dir")
+ test_dir = self._get_test_dir("write_graph_dir")
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph(),
os.path.join(test_dir, "l1"), "graph.pbtxt")
@@ -1881,7 +1901,7 @@ class WriteGraphTest(test.TestCase):
def testRecursiveCreate(self):
- test_dir = _TestDir("deep_dir")
+ test_dir = self._get_test_dir("deep_dir")
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
os.path.join(test_dir, "l1", "l2", "l3"),
@@ -1935,6 +1955,11 @@ class SaverUtilsTest(test.TestCase):
class ScopedGraphTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
graph = ops_lib.Graph()
with graph.as_default():
@@ -2067,7 +2092,7 @@ class ScopedGraphTest(test.TestCase):
# Verifies that we can save the subgraph under "hidden1" and restore it
# into "new_hidden1" in the new graph.
def testScopedSaveAndRestore(self):
- test_dir = _TestDir("scoped_export_import")
+ test_dir = self._get_test_dir("scoped_export_import")
ckpt_filename = "ckpt"
self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",
@@ -2076,7 +2101,7 @@ class ScopedGraphTest(test.TestCase):
# Verifies that we can copy the subgraph under "hidden1" and copy it
# to different name scope in the same graph or different graph.
def testCopyScopedGraph(self):
- test_dir = _TestDir("scoped_copy")
+ test_dir = self._get_test_dir("scoped_copy")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
graph1 = ops_lib.Graph()
with graph1.as_default():
@@ -2132,7 +2157,7 @@ class ScopedGraphTest(test.TestCase):
self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
def testExportGraphDefWithScope(self):
- test_dir = _TestDir("export_graph_def")
+ test_dir = self._get_test_dir("export_graph_def")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
graph1 = ops_lib.Graph()
with graph1.as_default():