aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/supervisor_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/supervisor_test.py')
-rw-r--r--tensorflow/python/training/supervisor_test.py73
1 files changed, 36 insertions, 37 deletions
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index 888f9f930a..4abce85852 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -64,15 +64,14 @@ def _summary_iterator(test_dir):
return summary_iterator.summary_iterator(event_paths[-1])
-def _test_dir(test_name):
- test_dir = os.path.join(test.get_temp_dir(), test_name)
- if os.path.exists(test_dir):
- shutil.rmtree(test_dir)
- return test_dir
-
-
class SupervisorTest(test.TestCase):
+ def _test_dir(self, test_name):
+ test_dir = os.path.join(self.get_temp_dir(), test_name)
+ if os.path.exists(test_dir):
+ shutil.rmtree(test_dir)
+ return test_dir
+
def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
"""Wait for a checkpoint file to appear.
@@ -94,7 +93,7 @@ class SupervisorTest(test.TestCase):
# This test does not test much.
def testBasics(self):
- logdir = _test_dir("basics")
+ logdir = self._test_dir("basics")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
@@ -105,7 +104,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testManagedSession(self):
- logdir = _test_dir("managed_session")
+ logdir = self._test_dir("managed_session")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
@@ -116,7 +115,7 @@ class SupervisorTest(test.TestCase):
self.assertTrue(sv.should_stop())
def testManagedSessionUserError(self):
- logdir = _test_dir("managed_user_error")
+ logdir = self._test_dir("managed_user_error")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
@@ -134,7 +133,7 @@ class SupervisorTest(test.TestCase):
self.assertEqual(1, last_step)
def testManagedSessionIgnoreOutOfRangeError(self):
- logdir = _test_dir("managed_out_of_range")
+ logdir = self._test_dir("managed_out_of_range")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
@@ -152,7 +151,7 @@ class SupervisorTest(test.TestCase):
self.assertEqual(3, last_step)
def testManagedSessionDoNotKeepSummaryWriter(self):
- logdir = _test_dir("managed_not_keep_summary_writer")
+ logdir = self._test_dir("managed_not_keep_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
@@ -204,7 +203,7 @@ class SupervisorTest(test.TestCase):
next(rr)
def testManagedSessionKeepSummaryWriter(self):
- logdir = _test_dir("managed_keep_summary_writer")
+ logdir = self._test_dir("managed_keep_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
@@ -266,7 +265,7 @@ class SupervisorTest(test.TestCase):
def testManagedEndOfInputOneQueue(self):
# Tests that the supervisor finishes without an error when using
# a fixed number of epochs, reading from a single queue.
- logdir = _test_dir("managed_end_of_input_one_queue")
+ logdir = self._test_dir("managed_end_of_input_one_queue")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with ops.Graph().as_default():
@@ -285,7 +284,7 @@ class SupervisorTest(test.TestCase):
# Tests that the supervisor finishes without an error when using
# a fixed number of epochs, reading from two queues, the second
# one producing a batch from the first one.
- logdir = _test_dir("managed_end_of_input_two_queues")
+ logdir = self._test_dir("managed_end_of_input_two_queues")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with ops.Graph().as_default():
@@ -304,7 +303,7 @@ class SupervisorTest(test.TestCase):
def testManagedMainErrorTwoQueues(self):
# Tests that the supervisor correctly raises a main loop
# error even when using multiple queues for input.
- logdir = _test_dir("managed_main_error_two_queues")
+ logdir = self._test_dir("managed_main_error_two_queues")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):
@@ -327,7 +326,7 @@ class SupervisorTest(test.TestCase):
sess.run(shuff_rec)
def testSessionConfig(self):
- logdir = _test_dir("session_config")
+ logdir = self._test_dir("session_config")
with ops.Graph().as_default():
with ops.device("/cpu:1"):
my_op = constant_op.constant([1.0])
@@ -340,7 +339,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testChiefCanWriteEvents(self):
- logdir = _test_dir("can_write")
+ logdir = self._test_dir("can_write")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
@@ -421,7 +420,7 @@ class SupervisorTest(test.TestCase):
sv.summary_computed(sess, sess.run(summ))
def testLogdirButExplicitlyNoSummaryWriter(self):
- logdir = _test_dir("explicit_no_summary_writer")
+ logdir = self._test_dir("explicit_no_summary_writer")
with ops.Graph().as_default():
variables.Variable([1.0], name="foo")
summary.scalar("c1", constant_op.constant(1))
@@ -437,7 +436,7 @@ class SupervisorTest(test.TestCase):
sv.summary_computed(sess, sess.run(summ))
def testNoLogdirButExplicitSummaryWriter(self):
- logdir = _test_dir("explicit_summary_writer")
+ logdir = self._test_dir("explicit_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
@@ -506,7 +505,7 @@ class SupervisorTest(test.TestCase):
sv.prepare_or_wait_for_session("")
def testInitOp(self):
- logdir = _test_dir("default_init_op")
+ logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
v = variables.Variable([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir=logdir)
@@ -515,7 +514,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testInitFn(self):
- logdir = _test_dir("default_init_op")
+ logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
v = variables.Variable([1.0, 2.0, 3.0])
@@ -528,7 +527,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testInitOpWithFeedDict(self):
- logdir = _test_dir("feed_dict_init_op")
+ logdir = self._test_dir("feed_dict_init_op")
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
v = variables.Variable(p, name="v")
@@ -542,7 +541,7 @@ class SupervisorTest(test.TestCase):
def testReadyForLocalInitOp(self):
server = server_lib.Server.create_local_server()
- logdir = _test_dir("default_ready_for_local_init_op")
+ logdir = self._test_dir("default_ready_for_local_init_op")
uid = uuid.uuid4().hex
@@ -584,7 +583,7 @@ class SupervisorTest(test.TestCase):
def testReadyForLocalInitOpRestoreFromCheckpoint(self):
server = server_lib.Server.create_local_server()
- logdir = _test_dir("ready_for_local_init_op_restore")
+ logdir = self._test_dir("ready_for_local_init_op_restore")
uid = uuid.uuid4().hex
@@ -639,7 +638,7 @@ class SupervisorTest(test.TestCase):
sv1.stop()
def testLocalInitOp(self):
- logdir = _test_dir("default_local_init_op")
+ logdir = self._test_dir("default_local_init_op")
with ops.Graph().as_default():
# A local variable.
v = variables.Variable(
@@ -664,7 +663,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testLocalInitOpForNonChief(self):
- logdir = _test_dir("default_local_init_op_non_chief")
+ logdir = self._test_dir("default_local_init_op_non_chief")
with ops.Graph().as_default():
with ops.device("/job:localhost"):
# A local variable.
@@ -685,7 +684,7 @@ class SupervisorTest(test.TestCase):
def testInitOpFails(self):
server = server_lib.Server.create_local_server()
- logdir = _test_dir("default_init_op_fails")
+ logdir = self._test_dir("default_init_op_fails")
with ops.Graph().as_default():
v = variables.Variable([1.0, 2.0, 3.0], name="v")
variables.Variable([4.0, 5.0, 6.0], name="w")
@@ -697,7 +696,7 @@ class SupervisorTest(test.TestCase):
def testInitOpFailsForTransientVariable(self):
server = server_lib.Server.create_local_server()
- logdir = _test_dir("default_init_op_fails_for_local_variable")
+ logdir = self._test_dir("default_init_op_fails_for_local_variable")
with ops.Graph().as_default():
v = variables.Variable(
[1.0, 2.0, 3.0],
@@ -714,7 +713,7 @@ class SupervisorTest(test.TestCase):
sv.prepare_or_wait_for_session(server.target)
def testSetupFail(self):
- logdir = _test_dir("setup_fail")
+ logdir = self._test_dir("setup_fail")
with ops.Graph().as_default():
variables.Variable([1.0, 2.0, 3.0], name="v")
with self.assertRaisesRegexp(ValueError, "must have their device set"):
@@ -724,7 +723,7 @@ class SupervisorTest(test.TestCase):
supervisor.Supervisor(logdir=logdir, is_chief=False)
def testDefaultGlobalStep(self):
- logdir = _test_dir("default_global_step")
+ logdir = self._test_dir("default_global_step")
with ops.Graph().as_default():
variables.Variable(287, name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
@@ -733,7 +732,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testRestoreFromMetaGraph(self):
- logdir = _test_dir("restore_from_meta_graph")
+ logdir = self._test_dir("restore_from_meta_graph")
with ops.Graph().as_default():
variables.Variable(1, name="v0")
sv = supervisor.Supervisor(logdir=logdir)
@@ -754,7 +753,7 @@ class SupervisorTest(test.TestCase):
# right away and get to run once before sv.stop() returns.
# We still sleep a bit to make the test robust.
def testStandardServicesWithoutGlobalStep(self):
- logdir = _test_dir("standard_services_without_global_step")
+ logdir = self._test_dir("standard_services_without_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
v = variables.Variable([1.0], name="foo")
@@ -804,7 +803,7 @@ class SupervisorTest(test.TestCase):
# Same as testStandardServicesNoGlobalStep but with a global step.
# We should get a summary about the step time.
def testStandardServicesWithGlobalStep(self):
- logdir = _test_dir("standard_services_with_global_step")
+ logdir = self._test_dir("standard_services_with_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
v = variables.Variable([123], name="global_step")
@@ -867,12 +866,12 @@ class SupervisorTest(test.TestCase):
def testNoQueueRunners(self):
with ops.Graph().as_default(), self.test_session() as sess:
- sv = supervisor.Supervisor(logdir=_test_dir("no_queue_runners"))
+ sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
self.assertEqual(0, len(sv.start_queue_runners(sess)))
sv.stop()
def testPrepareSessionAfterStopForChief(self):
- logdir = _test_dir("prepare_after_stop_chief")
+ logdir = self._test_dir("prepare_after_stop_chief")
with ops.Graph().as_default():
sv = supervisor.Supervisor(logdir=logdir, is_chief=True)
@@ -891,7 +890,7 @@ class SupervisorTest(test.TestCase):
self.assertTrue(sv.should_stop())
def testPrepareSessionAfterStopForNonChief(self):
- logdir = _test_dir("prepare_after_stop_nonchief")
+ logdir = self._test_dir("prepare_after_stop_nonchief")
with ops.Graph().as_default():
sv = supervisor.Supervisor(logdir=logdir, is_chief=False)