diff options
Diffstat (limited to 'tensorflow/python/training/saver_test.py')
-rw-r--r-- | tensorflow/python/training/saver_test.py | 95 |
1 files changed, 60 insertions, 35 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 78e2b073df..42c1cab802 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -68,18 +68,6 @@ from tensorflow.python.training import saver as saver_module from tensorflow.python.util import compat -# pylint: disable=invalid-name -def _TestDir(test_name): - test_dir = os.path.join(test.get_temp_dir(), test_name) - if os.path.exists(test_dir): - shutil.rmtree(test_dir) - gfile.MakeDirs(test_dir) - return test_dir - - -# pylint: enable=invalid-name - - class CheckpointedOp(object): """Op with a custom checkpointing implementation. @@ -591,6 +579,11 @@ class SaverTest(test.TestCase): class SaveRestoreShardedTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testBasics(self): save_path = os.path.join(self.get_temp_dir(), "sharded_basics") @@ -719,7 +712,9 @@ class SaveRestoreShardedTest(test.TestCase): var_full_shape = [10, 3] # Allows save/restore mechanism to work w/ different slicings. var_name = "my_var" - saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt") + saved_dir = self._get_test_dir("partitioned_variables") + saved_path = os.path.join(saved_dir, "ckpt") + call_saver_with_dict = False # updated by test loop below def _save(slices=None, partitioner=None): @@ -842,8 +837,13 @@ class SaveRestoreShardedTest(test.TestCase): class MaxToKeepTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testNonSharded(self): - save_dir = _TestDir("max_to_keep_non_sharded") + save_dir = self._get_test_dir("max_to_keep_non_sharded") with self.test_session() as sess: v = variables.Variable(10.0, name="v") @@ -963,7 +963,7 @@ class MaxToKeepTest(test.TestCase): saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) def testSharded(self): - save_dir = _TestDir("max_to_keep_sharded") + save_dir = self._get_test_dir("max_to_keep_sharded") with session.Session( target="", @@ -1018,8 +1018,8 @@ class MaxToKeepTest(test.TestCase): self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3))) def testNoMaxToKeep(self): - save_dir = _TestDir("no_max_to_keep") - save_dir2 = _TestDir("max_to_keep_0") + save_dir = self._get_test_dir("no_max_to_keep") + save_dir2 = self._get_test_dir("max_to_keep_0") with self.test_session() as sess: v = variables.Variable(10.0, name="v") @@ -1046,7 +1046,7 @@ class MaxToKeepTest(test.TestCase): self.assertTrue(saver_module.checkpoint_exists(s2)) def testNoMetaGraph(self): - save_dir = _TestDir("no_meta_graph") + save_dir = self._get_test_dir("no_meta_graph") with self.test_session() as sess: v = variables.Variable(10.0, name="v") @@ -1060,8 +1060,13 @@ class MaxToKeepTest(test.TestCase): class KeepCheckpointEveryNHoursTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testNonSharded(self): - save_dir = _TestDir("keep_checkpoint_every_n_hours") + save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") with self.test_session() as sess: v = variables.Variable([10.0], name="v") @@ -1277,8 +1282,13 @@ class LatestCheckpointWithRelativePaths(test.TestCase): class CheckpointStateTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testAbsPath(self): - save_dir = _TestDir("abs_paths") + save_dir = self._get_test_dir("abs_paths") abs_path = os.path.join(save_dir, "model-0") ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path) self.assertEqual(ckpt.model_checkpoint_path, abs_path) @@ -1297,7 +1307,7 @@ class CheckpointStateTest(test.TestCase): self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) def testAllModelCheckpointPaths(self): - save_dir = _TestDir("all_models_test") + save_dir = self._get_test_dir("all_models_test") abs_path = os.path.join(save_dir, "model-0") for paths in [None, [], ["model-2"]]: ckpt = saver_module.generate_checkpoint_state_proto( @@ -1309,7 +1319,7 @@ class CheckpointStateTest(test.TestCase): self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) def testUpdateCheckpointState(self): - save_dir = _TestDir("update_checkpoint_state") + save_dir = self._get_test_dir("update_checkpoint_state") os.chdir(save_dir) # Make a temporary train directory. train_dir = "train" @@ -1325,7 +1335,7 @@ class CheckpointStateTest(test.TestCase): self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path) def testCheckPointStateFailsWhenIncomplete(self): - save_dir = _TestDir("checkpoint_state_fails_when_incomplete") + save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete") os.chdir(save_dir) ckpt_path = os.path.join(save_dir, "checkpoint") ckpt_file = open(ckpt_path, "w") @@ -1335,7 +1345,7 @@ class CheckpointStateTest(test.TestCase): saver_module.get_checkpoint_state(save_dir) def testCheckPointCompletesRelativePaths(self): - save_dir = _TestDir("checkpoint_completes_relative_paths") + save_dir = self._get_test_dir("checkpoint_completes_relative_paths") os.chdir(save_dir) ckpt_path = os.path.join(save_dir, "checkpoint") ckpt_file = open(ckpt_path, "w") @@ -1356,8 +1366,13 @@ class CheckpointStateTest(test.TestCase): class MetaGraphTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testAddCollectionDef(self): - test_dir = _TestDir("good_collection") + test_dir = self._get_test_dir("good_collection") filename = os.path.join(test_dir, "metafile") with self.test_session(): # Creates a graph. @@ -1504,12 +1519,12 @@ class MetaGraphTest(test.TestCase): self.assertEqual(11.0, v1.eval()) def testMultiSaverCollection(self): - test_dir = _TestDir("saver_collection") + test_dir = self._get_test_dir("saver_collection") self._testMultiSaverCollectionSave(test_dir) self._testMultiSaverCollectionRestore(test_dir) def testBinaryAndTextFormat(self): - test_dir = _TestDir("binary_and_text") + test_dir = self._get_test_dir("binary_and_text") filename = os.path.join(test_dir, "metafile") with self.test_session(graph=ops_lib.Graph()): # Creates a graph. @@ -1541,7 +1556,7 @@ class MetaGraphTest(test.TestCase): saver_module.import_meta_graph(filename) def testSliceVariable(self): - test_dir = _TestDir("slice_saver") + test_dir = self._get_test_dir("slice_saver") filename = os.path.join(test_dir, "metafile") with self.test_session(): v1 = variables.Variable([20.0], name="v1") @@ -1679,7 +1694,7 @@ class MetaGraphTest(test.TestCase): sess.run(train_op) def testGraphExtension(self): - test_dir = _TestDir("graph_extension") + test_dir = self._get_test_dir("graph_extension") self._testGraphExtensionSave(test_dir) self._testGraphExtensionRestore(test_dir) self._testRestoreFromTrainGraphWithControlContext(test_dir) @@ -1722,7 +1737,7 @@ class MetaGraphTest(test.TestCase): def testImportIntoNamescope(self): # Test that we can import a meta graph into a namescope. - test_dir = _TestDir("import_into_namescope") + test_dir = self._get_test_dir("import_into_namescope") filename = os.path.join(test_dir, "ckpt") image = array_ops.placeholder(dtypes.float32, [None, 784]) label = array_ops.placeholder(dtypes.float32, [None, 10]) @@ -1870,8 +1885,13 @@ class CheckpointReaderForV2Test(CheckpointReaderTest): class WriteGraphTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testWriteGraph(self): - test_dir = _TestDir("write_graph_dir") + test_dir = self._get_test_dir("write_graph_dir") variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") path = graph_io.write_graph(ops_lib.get_default_graph(), os.path.join(test_dir, "l1"), "graph.pbtxt") @@ -1881,7 +1901,7 @@ class WriteGraphTest(test.TestCase): def testRecursiveCreate(self): - test_dir = _TestDir("deep_dir") + test_dir = self._get_test_dir("deep_dir") variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(), os.path.join(test_dir, "l1", "l2", "l3"), @@ -1935,6 +1955,11 @@ class SaverUtilsTest(test.TestCase): class ScopedGraphTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def _testScopedSave(self, test_dir, exported_filename, ckpt_filename): graph = ops_lib.Graph() with graph.as_default(): @@ -2067,7 +2092,7 @@ class ScopedGraphTest(test.TestCase): # Verifies that we can save the subgraph under "hidden1" and restore it # into "new_hidden1" in the new graph. def testScopedSaveAndRestore(self): - test_dir = _TestDir("scoped_export_import") + test_dir = self._get_test_dir("scoped_export_import") ckpt_filename = "ckpt" self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename) self._testScopedRestore(test_dir, "exported_hidden1.pbtxt", @@ -2076,7 +2101,7 @@ class ScopedGraphTest(test.TestCase): # Verifies that we can copy the subgraph under "hidden1" and copy it # to different name scope in the same graph or different graph. def testCopyScopedGraph(self): - test_dir = _TestDir("scoped_copy") + test_dir = self._get_test_dir("scoped_copy") saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") graph1 = ops_lib.Graph() with graph1.as_default(): @@ -2132,7 +2157,7 @@ class ScopedGraphTest(test.TestCase): self.assertAllClose(expected, sess.run("new_hidden1/relu:0")) def testExportGraphDefWithScope(self): - test_dir = _TestDir("export_graph_def") + test_dir = self._get_test_dir("export_graph_def") saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") graph1 = ops_lib.Graph() with graph1.as_default(): |