diff options
Diffstat (limited to 'tensorflow/python/training/supervisor_test.py')
-rw-r--r-- | tensorflow/python/training/supervisor_test.py | 73 |
1 files changed, 36 insertions, 37 deletions
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index 888f9f930a..4abce85852 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -64,15 +64,14 @@ def _summary_iterator(test_dir): return summary_iterator.summary_iterator(event_paths[-1]) -def _test_dir(test_name): - test_dir = os.path.join(test.get_temp_dir(), test_name) - if os.path.exists(test_dir): - shutil.rmtree(test_dir) - return test_dir - - class SupervisorTest(test.TestCase): + def _test_dir(self, test_name): + test_dir = os.path.join(self.get_temp_dir(), test_name) + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + return test_dir + def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True): """Wait for a checkpoint file to appear. @@ -94,7 +93,7 @@ class SupervisorTest(test.TestCase): # This test does not test much. def testBasics(self): - logdir = _test_dir("basics") + logdir = self._test_dir("basics") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) @@ -105,7 +104,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testManagedSession(self): - logdir = _test_dir("managed_session") + logdir = self._test_dir("managed_session") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) @@ -116,7 +115,7 @@ class SupervisorTest(test.TestCase): self.assertTrue(sv.should_stop()) def testManagedSessionUserError(self): - logdir = _test_dir("managed_user_error") + logdir = self._test_dir("managed_user_error") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) @@ -134,7 +133,7 @@ class SupervisorTest(test.TestCase): self.assertEqual(1, last_step) def testManagedSessionIgnoreOutOfRangeError(self): - logdir = _test_dir("managed_out_of_range") + logdir = self._test_dir("managed_out_of_range") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) @@ -152,7 +151,7 @@ class SupervisorTest(test.TestCase): self.assertEqual(3, last_step) def testManagedSessionDoNotKeepSummaryWriter(self): - logdir = _test_dir("managed_not_keep_summary_writer") + logdir = self._test_dir("managed_not_keep_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) @@ -204,7 +203,7 @@ class SupervisorTest(test.TestCase): next(rr) def testManagedSessionKeepSummaryWriter(self): - logdir = _test_dir("managed_keep_summary_writer") + logdir = self._test_dir("managed_keep_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) @@ -266,7 +265,7 @@ class SupervisorTest(test.TestCase): def testManagedEndOfInputOneQueue(self): # Tests that the supervisor finishes without an error when using # a fixed number of epochs, reading from a single queue. - logdir = _test_dir("managed_end_of_input_one_queue") + logdir = self._test_dir("managed_end_of_input_one_queue") os.makedirs(logdir) data_path = self._csv_data(logdir) with ops.Graph().as_default(): @@ -285,7 +284,7 @@ class SupervisorTest(test.TestCase): # Tests that the supervisor finishes without an error when using # a fixed number of epochs, reading from two queues, the second # one producing a batch from the first one. - logdir = _test_dir("managed_end_of_input_two_queues") + logdir = self._test_dir("managed_end_of_input_two_queues") os.makedirs(logdir) data_path = self._csv_data(logdir) with ops.Graph().as_default(): @@ -304,7 +303,7 @@ class SupervisorTest(test.TestCase): def testManagedMainErrorTwoQueues(self): # Tests that the supervisor correctly raises a main loop # error even when using multiple queues for input. - logdir = _test_dir("managed_main_error_two_queues") + logdir = self._test_dir("managed_main_error_two_queues") os.makedirs(logdir) data_path = self._csv_data(logdir) with self.assertRaisesRegexp(RuntimeError, "fail at step 3"): @@ -327,7 +326,7 @@ class SupervisorTest(test.TestCase): sess.run(shuff_rec) def testSessionConfig(self): - logdir = _test_dir("session_config") + logdir = self._test_dir("session_config") with ops.Graph().as_default(): with ops.device("/cpu:1"): my_op = constant_op.constant([1.0]) @@ -340,7 +339,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testChiefCanWriteEvents(self): - logdir = _test_dir("can_write") + logdir = self._test_dir("can_write") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) @@ -421,7 +420,7 @@ class SupervisorTest(test.TestCase): sv.summary_computed(sess, sess.run(summ)) def testLogdirButExplicitlyNoSummaryWriter(self): - logdir = _test_dir("explicit_no_summary_writer") + logdir = self._test_dir("explicit_no_summary_writer") with ops.Graph().as_default(): variables.Variable([1.0], name="foo") summary.scalar("c1", constant_op.constant(1)) @@ -437,7 +436,7 @@ class SupervisorTest(test.TestCase): sv.summary_computed(sess, sess.run(summ)) def testNoLogdirButExplicitSummaryWriter(self): - logdir = _test_dir("explicit_summary_writer") + logdir = self._test_dir("explicit_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) @@ -506,7 +505,7 @@ class SupervisorTest(test.TestCase): sv.prepare_or_wait_for_session("") def testInitOp(self): - logdir = _test_dir("default_init_op") + logdir = self._test_dir("default_init_op") with ops.Graph().as_default(): v = variables.Variable([1.0, 2.0, 3.0]) sv = supervisor.Supervisor(logdir=logdir) @@ -515,7 +514,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testInitFn(self): - logdir = _test_dir("default_init_op") + logdir = self._test_dir("default_init_op") with ops.Graph().as_default(): v = variables.Variable([1.0, 2.0, 3.0]) @@ -528,7 +527,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testInitOpWithFeedDict(self): - logdir = _test_dir("feed_dict_init_op") + logdir = self._test_dir("feed_dict_init_op") with ops.Graph().as_default(): p = array_ops.placeholder(dtypes.float32, shape=(3,)) v = variables.Variable(p, name="v") @@ -542,7 +541,7 @@ class SupervisorTest(test.TestCase): def testReadyForLocalInitOp(self): server = server_lib.Server.create_local_server() - logdir = _test_dir("default_ready_for_local_init_op") + logdir = self._test_dir("default_ready_for_local_init_op") uid = uuid.uuid4().hex @@ -584,7 +583,7 @@ class SupervisorTest(test.TestCase): def testReadyForLocalInitOpRestoreFromCheckpoint(self): server = server_lib.Server.create_local_server() - logdir = _test_dir("ready_for_local_init_op_restore") + logdir = self._test_dir("ready_for_local_init_op_restore") uid = uuid.uuid4().hex @@ -639,7 +638,7 @@ class SupervisorTest(test.TestCase): sv1.stop() def testLocalInitOp(self): - logdir = _test_dir("default_local_init_op") + logdir = self._test_dir("default_local_init_op") with ops.Graph().as_default(): # A local variable. v = variables.Variable( @@ -664,7 +663,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testLocalInitOpForNonChief(self): - logdir = _test_dir("default_local_init_op_non_chief") + logdir = self._test_dir("default_local_init_op_non_chief") with ops.Graph().as_default(): with ops.device("/job:localhost"): # A local variable. @@ -685,7 +684,7 @@ class SupervisorTest(test.TestCase): def testInitOpFails(self): server = server_lib.Server.create_local_server() - logdir = _test_dir("default_init_op_fails") + logdir = self._test_dir("default_init_op_fails") with ops.Graph().as_default(): v = variables.Variable([1.0, 2.0, 3.0], name="v") variables.Variable([4.0, 5.0, 6.0], name="w") @@ -697,7 +696,7 @@ class SupervisorTest(test.TestCase): def testInitOpFailsForTransientVariable(self): server = server_lib.Server.create_local_server() - logdir = _test_dir("default_init_op_fails_for_local_variable") + logdir = self._test_dir("default_init_op_fails_for_local_variable") with ops.Graph().as_default(): v = variables.Variable( [1.0, 2.0, 3.0], @@ -714,7 +713,7 @@ class SupervisorTest(test.TestCase): sv.prepare_or_wait_for_session(server.target) def testSetupFail(self): - logdir = _test_dir("setup_fail") + logdir = self._test_dir("setup_fail") with ops.Graph().as_default(): variables.Variable([1.0, 2.0, 3.0], name="v") with self.assertRaisesRegexp(ValueError, "must have their device set"): @@ -724,7 +723,7 @@ class SupervisorTest(test.TestCase): supervisor.Supervisor(logdir=logdir, is_chief=False) def testDefaultGlobalStep(self): - logdir = _test_dir("default_global_step") + logdir = self._test_dir("default_global_step") with ops.Graph().as_default(): variables.Variable(287, name="global_step") sv = supervisor.Supervisor(logdir=logdir) @@ -733,7 +732,7 @@ class SupervisorTest(test.TestCase): sv.stop() def testRestoreFromMetaGraph(self): - logdir = _test_dir("restore_from_meta_graph") + logdir = self._test_dir("restore_from_meta_graph") with ops.Graph().as_default(): variables.Variable(1, name="v0") sv = supervisor.Supervisor(logdir=logdir) @@ -754,7 +753,7 @@ class SupervisorTest(test.TestCase): # right away and get to run once before sv.stop() returns. # We still sleep a bit to make the test robust. def testStandardServicesWithoutGlobalStep(self): - logdir = _test_dir("standard_services_without_global_step") + logdir = self._test_dir("standard_services_without_global_step") # Create a checkpoint. with ops.Graph().as_default(): v = variables.Variable([1.0], name="foo") @@ -804,7 +803,7 @@ class SupervisorTest(test.TestCase): # Same as testStandardServicesNoGlobalStep but with a global step. # We should get a summary about the step time. def testStandardServicesWithGlobalStep(self): - logdir = _test_dir("standard_services_with_global_step") + logdir = self._test_dir("standard_services_with_global_step") # Create a checkpoint. with ops.Graph().as_default(): v = variables.Variable([123], name="global_step") @@ -867,12 +866,12 @@ class SupervisorTest(test.TestCase): def testNoQueueRunners(self): with ops.Graph().as_default(), self.test_session() as sess: - sv = supervisor.Supervisor(logdir=_test_dir("no_queue_runners")) + sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners")) self.assertEqual(0, len(sv.start_queue_runners(sess))) sv.stop() def testPrepareSessionAfterStopForChief(self): - logdir = _test_dir("prepare_after_stop_chief") + logdir = self._test_dir("prepare_after_stop_chief") with ops.Graph().as_default(): sv = supervisor.Supervisor(logdir=logdir, is_chief=True) @@ -891,7 +890,7 @@ class SupervisorTest(test.TestCase): self.assertTrue(sv.should_stop()) def testPrepareSessionAfterStopForNonChief(self): - logdir = _test_dir("prepare_after_stop_nonchief") + logdir = self._test_dir("prepare_after_stop_nonchief") with ops.Graph().as_default(): sv = supervisor.Supervisor(logdir=logdir, is_chief=False) |