diff options
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/input.py | 3 | ||||
-rw-r--r-- | tensorflow/python/training/queue_runner_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 95 | ||||
-rw-r--r-- | tensorflow/python/training/supervisor_test.py | 73 |
4 files changed, 99 insertions, 74 deletions
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 042e68df76..cadb404a83 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -64,7 +64,8 @@ def match_filenames_once(pattern, name=None): """ with ops.name_scope(name, "matching_filenames", [pattern]) as name: return variables.Variable(io_ops.matching_files(pattern), trainable=False, - name=name, validate_shape=False) + name=name, validate_shape=False, + collections=[ops.GraphKeys.LOCAL_VARIABLES]) def limit_epochs(tensor, num_epochs=None, name=None): diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py index 77317283e9..5b00ac9fc3 100644 --- a/tensorflow/python/training/queue_runner_test.py +++ b/tensorflow/python/training/queue_runner_test.py @@ -164,7 +164,7 @@ class QueueRunnerTest(test.TestCase): coord.request_stop() # We should be able to join because the RequestStop() will cause # the queue to be closed and the enqueue to terminate. - coord.join(stop_grace_period_secs=0.05) + coord.join(stop_grace_period_secs=1.0) def testMultipleSessions(self): with self.test_session() as sess: diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 78e2b073df..42c1cab802 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -68,18 +68,6 @@ from tensorflow.python.training import saver as saver_module from tensorflow.python.util import compat -# pylint: disable=invalid-name -def _TestDir(test_name): - test_dir = os.path.join(test.get_temp_dir(), test_name) - if os.path.exists(test_dir): - shutil.rmtree(test_dir) - gfile.MakeDirs(test_dir) - return test_dir - - -# pylint: enable=invalid-name - - class CheckpointedOp(object): """Op with a custom checkpointing implementation. @@ -591,6 +579,11 @@ class SaverTest(test.TestCase): class SaveRestoreShardedTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testBasics(self): save_path = os.path.join(self.get_temp_dir(), "sharded_basics") @@ -719,7 +712,9 @@ class SaveRestoreShardedTest(test.TestCase): var_full_shape = [10, 3] # Allows save/restore mechanism to work w/ different slicings. var_name = "my_var" - saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt") + saved_dir = self._get_test_dir("partitioned_variables") + saved_path = os.path.join(saved_dir, "ckpt") + call_saver_with_dict = False # updated by test loop below def _save(slices=None, partitioner=None): @@ -842,8 +837,13 @@ class SaveRestoreShardedTest(test.TestCase): class MaxToKeepTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testNonSharded(self): - save_dir = _TestDir("max_to_keep_non_sharded") + save_dir = self._get_test_dir("max_to_keep_non_sharded") with self.test_session() as sess: v = variables.Variable(10.0, name="v") @@ -963,7 +963,7 @@ class MaxToKeepTest(test.TestCase): saver_module.checkpoint_exists(save._MetaGraphFilename(s1))) def testSharded(self): - save_dir = _TestDir("max_to_keep_sharded") + save_dir = self._get_test_dir("max_to_keep_sharded") with session.Session( target="", @@ -1018,8 +1018,8 @@ class MaxToKeepTest(test.TestCase): self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3))) def testNoMaxToKeep(self): - save_dir = _TestDir("no_max_to_keep") - save_dir2 = _TestDir("max_to_keep_0") + save_dir = self._get_test_dir("no_max_to_keep") + save_dir2 = self._get_test_dir("max_to_keep_0") with self.test_session() as sess: v = variables.Variable(10.0, name="v") @@ -1046,7 +1046,7 @@ class MaxToKeepTest(test.TestCase): self.assertTrue(saver_module.checkpoint_exists(s2)) def testNoMetaGraph(self): - save_dir = _TestDir("no_meta_graph") + save_dir = self._get_test_dir("no_meta_graph") with self.test_session() as sess: v = variables.Variable(10.0, name="v") @@ -1060,8 +1060,13 @@ class MaxToKeepTest(test.TestCase): class KeepCheckpointEveryNHoursTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testNonSharded(self): - save_dir = _TestDir("keep_checkpoint_every_n_hours") + save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") with self.test_session() as sess: v = variables.Variable([10.0], name="v") @@ -1277,8 +1282,13 @@ class LatestCheckpointWithRelativePaths(test.TestCase): class CheckpointStateTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testAbsPath(self): - save_dir = _TestDir("abs_paths") + save_dir = self._get_test_dir("abs_paths") abs_path = os.path.join(save_dir, "model-0") ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path) self.assertEqual(ckpt.model_checkpoint_path, abs_path) @@ -1297,7 +1307,7 @@ class CheckpointStateTest(test.TestCase): self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) def testAllModelCheckpointPaths(self): - save_dir = _TestDir("all_models_test") + save_dir = self._get_test_dir("all_models_test") abs_path = os.path.join(save_dir, "model-0") for paths in [None, [], ["model-2"]]: ckpt = saver_module.generate_checkpoint_state_proto( @@ -1309,7 +1319,7 @@ class CheckpointStateTest(test.TestCase): self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) def testUpdateCheckpointState(self): - save_dir = _TestDir("update_checkpoint_state") + save_dir = self._get_test_dir("update_checkpoint_state") os.chdir(save_dir) # Make a temporary train directory. train_dir = "train" @@ -1325,7 +1335,7 @@ class CheckpointStateTest(test.TestCase): self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path) def testCheckPointStateFailsWhenIncomplete(self): - save_dir = _TestDir("checkpoint_state_fails_when_incomplete") + save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete") os.chdir(save_dir) ckpt_path = os.path.join(save_dir, "checkpoint") ckpt_file = open(ckpt_path, "w") @@ -1335,7 +1345,7 @@ class CheckpointStateTest(test.TestCase): saver_module.get_checkpoint_state(save_dir) def testCheckPointCompletesRelativePaths(self): - save_dir = _TestDir("checkpoint_completes_relative_paths") + save_dir = self._get_test_dir("checkpoint_completes_relative_paths") os.chdir(save_dir) ckpt_path = os.path.join(save_dir, "checkpoint") ckpt_file = open(ckpt_path, "w") @@ -1356,8 +1366,13 @@ class CheckpointStateTest(test.TestCase): class MetaGraphTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testAddCollectionDef(self): - test_dir = _TestDir("good_collection") + test_dir = self._get_test_dir("good_collection") filename = os.path.join(test_dir, "metafile") with self.test_session(): # Creates a graph. @@ -1504,12 +1519,12 @@ class MetaGraphTest(test.TestCase): self.assertEqual(11.0, v1.eval()) def testMultiSaverCollection(self): - test_dir = _TestDir("saver_collection") + test_dir = self._get_test_dir("saver_collection") self._testMultiSaverCollectionSave(test_dir) self._testMultiSaverCollectionRestore(test_dir) def testBinaryAndTextFormat(self): - test_dir = _TestDir("binary_and_text") + test_dir = self._get_test_dir("binary_and_text") filename = os.path.join(test_dir, "metafile") with self.test_session(graph=ops_lib.Graph()): # Creates a graph. @@ -1541,7 +1556,7 @@ class MetaGraphTest(test.TestCase): saver_module.import_meta_graph(filename) def testSliceVariable(self): - test_dir = _TestDir("slice_saver") + test_dir = self._get_test_dir("slice_saver") filename = os.path.join(test_dir, "metafile") with self.test_session(): v1 = variables.Variable([20.0], name="v1") @@ -1679,7 +1694,7 @@ class MetaGraphTest(test.TestCase): sess.run(train_op) def testGraphExtension(self): - test_dir = _TestDir("graph_extension") + test_dir = self._get_test_dir("graph_extension") self._testGraphExtensionSave(test_dir) self._testGraphExtensionRestore(test_dir) self._testRestoreFromTrainGraphWithControlContext(test_dir) @@ -1722,7 +1737,7 @@ class MetaGraphTest(test.TestCase): def testImportIntoNamescope(self): # Test that we can import a meta graph into a namescope. - test_dir = _TestDir("import_into_namescope") + test_dir = self._get_test_dir("import_into_namescope") filename = os.path.join(test_dir, "ckpt") image = array_ops.placeholder(dtypes.float32, [None, 784]) label = array_ops.placeholder(dtypes.float32, [None, 10]) @@ -1870,8 +1885,13 @@ class CheckpointReaderForV2Test(CheckpointReaderTest): class WriteGraphTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def testWriteGraph(self): - test_dir = _TestDir("write_graph_dir") + test_dir = self._get_test_dir("write_graph_dir") variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") path = graph_io.write_graph(ops_lib.get_default_graph(), os.path.join(test_dir, "l1"), "graph.pbtxt") @@ -1881,7 +1901,7 @@ class WriteGraphTest(test.TestCase): def testRecursiveCreate(self): - test_dir = _TestDir("deep_dir") + test_dir = self._get_test_dir("deep_dir") variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(), os.path.join(test_dir, "l1", "l2", "l3"), @@ -1935,6 +1955,11 @@ class SaverUtilsTest(test.TestCase): class ScopedGraphTest(test.TestCase): + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + def _testScopedSave(self, test_dir, exported_filename, ckpt_filename): graph = ops_lib.Graph() with graph.as_default(): @@ -2067,7 +2092,7 @@ class ScopedGraphTest(test.TestCase): # Verifies that we can save the subgraph under "hidden1" and restore it # into "new_hidden1" in the new graph. def testScopedSaveAndRestore(self): - test_dir = _TestDir("scoped_export_import") + test_dir = self._get_test_dir("scoped_export_import") ckpt_filename = "ckpt" self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename) self._testScopedRestore(test_dir, "exported_hidden1.pbtxt", @@ -2076,7 +2101,7 @@ class ScopedGraphTest(test.TestCase): # Verifies that we can copy the subgraph under "hidden1" and copy it # to different name scope in the same graph or different graph. def testCopyScopedGraph(self): - test_dir = _TestDir("scoped_copy") + test_dir = self._get_test_dir("scoped_copy") saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") graph1 = ops_lib.Graph() with graph1.as_default(): @@ -2132,7 +2157,7 @@ class ScopedGraphTest(test.TestCase): self.assertAllClose(expected, sess.run("new_hidden1/relu:0")) def testExportGraphDefWithScope(self): - test_dir = _TestDir("export_graph_def") + test_dir = self._get_test_dir("export_graph_def") saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") graph1 = ops_lib.Graph() with graph1.as_default(): diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index 888f9f930a..4abce85852 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -64,15 +64,14 @@ def _summary_iterator(test_dir): return summary_iterator.summary_iterator(event_paths[-1]) -def _test_dir(test_name): - test_dir = os.path.join(test.get_temp_dir(), test_name) - if os.path.exists(test_dir): - shutil.rmtree(test_dir) - return test_dir - - class SupervisorTest(test.TestCase): + def _test_dir(self, test_name): + test_dir = os.path.join(self.get_temp_dir(), test_name) + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + return test_dir + def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True): """Wait for a checkpoint file to appear. @@ -94,7 +93,7 @@ class SupervisorTest(test.TestCase): # This test does not test much. def testBasics(self): - logdir = _test_dir("basics") + logdir = self._test_dir("basics") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) @@ -105,7 +104,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testManagedSession(self): - logdir = _test_dir("managed_session") + logdir = self._test_dir("managed_session") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) @@ -116,7 +115,7 @@ class SupervisorTest(test.TestCase): self.assertTrue(sv.should_stop()) def testManagedSessionUserError(self): - logdir = _test_dir("managed_user_error") + logdir = self._test_dir("managed_user_error") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) @@ -134,7 +133,7 @@ class SupervisorTest(test.TestCase): self.assertEqual(1, last_step) def testManagedSessionIgnoreOutOfRangeError(self): - logdir = _test_dir("managed_out_of_range") + logdir = self._test_dir("managed_out_of_range") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) @@ -152,7 +151,7 @@ class SupervisorTest(test.TestCase): self.assertEqual(3, last_step) def testManagedSessionDoNotKeepSummaryWriter(self): - logdir = _test_dir("managed_not_keep_summary_writer") + logdir = self._test_dir("managed_not_keep_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) @@ -204,7 +203,7 @@ class SupervisorTest(test.TestCase): next(rr) def testManagedSessionKeepSummaryWriter(self): - logdir = _test_dir("managed_keep_summary_writer") + logdir = self._test_dir("managed_keep_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) @@ -266,7 +265,7 @@ class SupervisorTest(test.TestCase): def testManagedEndOfInputOneQueue(self): # Tests that the supervisor finishes without an error when using # a fixed number of epochs, reading from a single queue. - logdir = _test_dir("managed_end_of_input_one_queue") + logdir = self._test_dir("managed_end_of_input_one_queue") os.makedirs(logdir) data_path = self._csv_data(logdir) with ops.Graph().as_default(): @@ -285,7 +284,7 @@ class SupervisorTest(test.TestCase): # Tests that the supervisor finishes without an error when using # a fixed number of epochs, reading from two queues, the second # one producing a batch from the first one. - logdir = _test_dir("managed_end_of_input_two_queues") + logdir = self._test_dir("managed_end_of_input_two_queues") os.makedirs(logdir) data_path = self._csv_data(logdir) with ops.Graph().as_default(): @@ -304,7 +303,7 @@ class SupervisorTest(test.TestCase): def testManagedMainErrorTwoQueues(self): # Tests that the supervisor correctly raises a main loop # error even when using multiple queues for input. - logdir = _test_dir("managed_main_error_two_queues") + logdir = self._test_dir("managed_main_error_two_queues") os.makedirs(logdir) data_path = self._csv_data(logdir) with self.assertRaisesRegexp(RuntimeError, "fail at step 3"): @@ -327,7 +326,7 @@ class SupervisorTest(test.TestCase): sess.run(shuff_rec) def testSessionConfig(self): - logdir = _test_dir("session_config") + logdir = self._test_dir("session_config") with ops.Graph().as_default(): with ops.device("/cpu:1"): my_op = constant_op.constant([1.0]) @@ -340,7 +339,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testChiefCanWriteEvents(self): - logdir = _test_dir("can_write") + logdir = self._test_dir("can_write") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) @@ -421,7 +420,7 @@ class SupervisorTest(test.TestCase): sv.summary_computed(sess, sess.run(summ)) def testLogdirButExplicitlyNoSummaryWriter(self): - logdir = _test_dir("explicit_no_summary_writer") + logdir = self._test_dir("explicit_no_summary_writer") with ops.Graph().as_default(): variables.Variable([1.0], name="foo") summary.scalar("c1", constant_op.constant(1)) @@ -437,7 +436,7 @@ class SupervisorTest(test.TestCase): sv.summary_computed(sess, sess.run(summ)) def testNoLogdirButExplicitSummaryWriter(self): - logdir = _test_dir("explicit_summary_writer") + logdir = self._test_dir("explicit_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) @@ -506,7 +505,7 @@ class SupervisorTest(test.TestCase): sv.prepare_or_wait_for_session("") def testInitOp(self): - logdir = _test_dir("default_init_op") + logdir = self._test_dir("default_init_op") with ops.Graph().as_default(): v = variables.Variable([1.0, 2.0, 3.0]) sv = supervisor.Supervisor(logdir=logdir) @@ -515,7 +514,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testInitFn(self): - logdir = _test_dir("default_init_op") + logdir = self._test_dir("default_init_op") with ops.Graph().as_default(): v = variables.Variable([1.0, 2.0, 3.0]) @@ -528,7 +527,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testInitOpWithFeedDict(self): - logdir = _test_dir("feed_dict_init_op") + logdir = self._test_dir("feed_dict_init_op") with ops.Graph().as_default(): p = array_ops.placeholder(dtypes.float32, shape=(3,)) v = variables.Variable(p, name="v") @@ -542,7 +541,7 @@ class SupervisorTest(test.TestCase): def testReadyForLocalInitOp(self): server = server_lib.Server.create_local_server() - logdir = _test_dir("default_ready_for_local_init_op") + logdir = self._test_dir("default_ready_for_local_init_op") uid = uuid.uuid4().hex @@ -584,7 +583,7 @@ class SupervisorTest(test.TestCase): def testReadyForLocalInitOpRestoreFromCheckpoint(self): server = server_lib.Server.create_local_server() - logdir = _test_dir("ready_for_local_init_op_restore") + logdir = self._test_dir("ready_for_local_init_op_restore") uid = uuid.uuid4().hex @@ -639,7 +638,7 @@ class SupervisorTest(test.TestCase): sv1.stop() def testLocalInitOp(self): - logdir = _test_dir("default_local_init_op") + logdir = self._test_dir("default_local_init_op") with ops.Graph().as_default(): # A local variable. v = variables.Variable( @@ -664,7 +663,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testLocalInitOpForNonChief(self): - logdir = _test_dir("default_local_init_op_non_chief") + logdir = self._test_dir("default_local_init_op_non_chief") with ops.Graph().as_default(): with ops.device("/job:localhost"): # A local variable. @@ -685,7 +684,7 @@ class SupervisorTest(test.TestCase): def testInitOpFails(self): server = server_lib.Server.create_local_server() - logdir = _test_dir("default_init_op_fails") + logdir = self._test_dir("default_init_op_fails") with ops.Graph().as_default(): v = variables.Variable([1.0, 2.0, 3.0], name="v") variables.Variable([4.0, 5.0, 6.0], name="w") @@ -697,7 +696,7 @@ class SupervisorTest(test.TestCase): def testInitOpFailsForTransientVariable(self): server = server_lib.Server.create_local_server() - logdir = _test_dir("default_init_op_fails_for_local_variable") + logdir = self._test_dir("default_init_op_fails_for_local_variable") with ops.Graph().as_default(): v = variables.Variable( [1.0, 2.0, 3.0], @@ -714,7 +713,7 @@ class SupervisorTest(test.TestCase): sv.prepare_or_wait_for_session(server.target) def testSetupFail(self): - logdir = _test_dir("setup_fail") + logdir = self._test_dir("setup_fail") with ops.Graph().as_default(): variables.Variable([1.0, 2.0, 3.0], name="v") with self.assertRaisesRegexp(ValueError, "must have their device set"): @@ -724,7 +723,7 @@ class SupervisorTest(test.TestCase): supervisor.Supervisor(logdir=logdir, is_chief=False) def testDefaultGlobalStep(self): - logdir = _test_dir("default_global_step") + logdir = self._test_dir("default_global_step") with ops.Graph().as_default(): variables.Variable(287, name="global_step") sv = supervisor.Supervisor(logdir=logdir) @@ -733,7 +732,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testRestoreFromMetaGraph(self): - logdir = _test_dir("restore_from_meta_graph") + logdir = self._test_dir("restore_from_meta_graph") with ops.Graph().as_default(): variables.Variable(1, name="v0") sv = supervisor.Supervisor(logdir=logdir) @@ -754,7 +753,7 @@ class SupervisorTest(test.TestCase): # right away and get to run once before sv.stop() returns. # We still sleep a bit to make the test robust. def testStandardServicesWithoutGlobalStep(self): - logdir = _test_dir("standard_services_without_global_step") + logdir = self._test_dir("standard_services_without_global_step") # Create a checkpoint. with ops.Graph().as_default(): v = variables.Variable([1.0], name="foo") @@ -804,7 +803,7 @@ class SupervisorTest(test.TestCase): # Same as testStandardServicesNoGlobalStep but with a global step. # We should get a summary about the step time. def testStandardServicesWithGlobalStep(self): - logdir = _test_dir("standard_services_with_global_step") + logdir = self._test_dir("standard_services_with_global_step") # Create a checkpoint. with ops.Graph().as_default(): v = variables.Variable([123], name="global_step") @@ -867,12 +866,12 @@ class SupervisorTest(test.TestCase): def testNoQueueRunners(self): with ops.Graph().as_default(), self.test_session() as sess: - sv = supervisor.Supervisor(logdir=_test_dir("no_queue_runners")) + sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners")) self.assertEqual(0, len(sv.start_queue_runners(sess))) sv.stop() def testPrepareSessionAfterStopForChief(self): - logdir = _test_dir("prepare_after_stop_chief") + logdir = self._test_dir("prepare_after_stop_chief") with ops.Graph().as_default(): sv = supervisor.Supervisor(logdir=logdir, is_chief=True) @@ -891,7 +890,7 @@ class SupervisorTest(test.TestCase): self.assertTrue(sv.should_stop()) def testPrepareSessionAfterStopForNonChief(self): - logdir = _test_dir("prepare_after_stop_nonchief") + logdir = self._test_dir("prepare_after_stop_nonchief") with ops.Graph().as_default(): sv = supervisor.Supervisor(logdir=logdir, is_chief=False) |