aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/input.py3
-rw-r--r--tensorflow/python/training/queue_runner_test.py2
-rw-r--r--tensorflow/python/training/saver_test.py95
-rw-r--r--tensorflow/python/training/supervisor_test.py73
4 files changed, 99 insertions, 74 deletions
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 042e68df76..cadb404a83 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -64,7 +64,8 @@ def match_filenames_once(pattern, name=None):
"""
with ops.name_scope(name, "matching_filenames", [pattern]) as name:
return variables.Variable(io_ops.matching_files(pattern), trainable=False,
- name=name, validate_shape=False)
+ name=name, validate_shape=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
def limit_epochs(tensor, num_epochs=None, name=None):
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index 77317283e9..5b00ac9fc3 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -164,7 +164,7 @@ class QueueRunnerTest(test.TestCase):
coord.request_stop()
# We should be able to join because the RequestStop() will cause
# the queue to be closed and the enqueue to terminate.
- coord.join(stop_grace_period_secs=0.05)
+ coord.join(stop_grace_period_secs=1.0)
def testMultipleSessions(self):
with self.test_session() as sess:
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 78e2b073df..42c1cab802 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -68,18 +68,6 @@ from tensorflow.python.training import saver as saver_module
from tensorflow.python.util import compat
-# pylint: disable=invalid-name
-def _TestDir(test_name):
- test_dir = os.path.join(test.get_temp_dir(), test_name)
- if os.path.exists(test_dir):
- shutil.rmtree(test_dir)
- gfile.MakeDirs(test_dir)
- return test_dir
-
-
-# pylint: enable=invalid-name
-
-
class CheckpointedOp(object):
"""Op with a custom checkpointing implementation.
@@ -591,6 +579,11 @@ class SaverTest(test.TestCase):
class SaveRestoreShardedTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testBasics(self):
save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
@@ -719,7 +712,9 @@ class SaveRestoreShardedTest(test.TestCase):
var_full_shape = [10, 3]
# Allows save/restore mechanism to work w/ different slicings.
var_name = "my_var"
- saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt")
+ saved_dir = self._get_test_dir("partitioned_variables")
+ saved_path = os.path.join(saved_dir, "ckpt")
+
call_saver_with_dict = False # updated by test loop below
def _save(slices=None, partitioner=None):
@@ -842,8 +837,13 @@ class SaveRestoreShardedTest(test.TestCase):
class MaxToKeepTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testNonSharded(self):
- save_dir = _TestDir("max_to_keep_non_sharded")
+ save_dir = self._get_test_dir("max_to_keep_non_sharded")
with self.test_session() as sess:
v = variables.Variable(10.0, name="v")
@@ -963,7 +963,7 @@ class MaxToKeepTest(test.TestCase):
saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
def testSharded(self):
- save_dir = _TestDir("max_to_keep_sharded")
+ save_dir = self._get_test_dir("max_to_keep_sharded")
with session.Session(
target="",
@@ -1018,8 +1018,8 @@ class MaxToKeepTest(test.TestCase):
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
def testNoMaxToKeep(self):
- save_dir = _TestDir("no_max_to_keep")
- save_dir2 = _TestDir("max_to_keep_0")
+ save_dir = self._get_test_dir("no_max_to_keep")
+ save_dir2 = self._get_test_dir("max_to_keep_0")
with self.test_session() as sess:
v = variables.Variable(10.0, name="v")
@@ -1046,7 +1046,7 @@ class MaxToKeepTest(test.TestCase):
self.assertTrue(saver_module.checkpoint_exists(s2))
def testNoMetaGraph(self):
- save_dir = _TestDir("no_meta_graph")
+ save_dir = self._get_test_dir("no_meta_graph")
with self.test_session() as sess:
v = variables.Variable(10.0, name="v")
@@ -1060,8 +1060,13 @@ class MaxToKeepTest(test.TestCase):
class KeepCheckpointEveryNHoursTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testNonSharded(self):
- save_dir = _TestDir("keep_checkpoint_every_n_hours")
+ save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")
with self.test_session() as sess:
v = variables.Variable([10.0], name="v")
@@ -1277,8 +1282,13 @@ class LatestCheckpointWithRelativePaths(test.TestCase):
class CheckpointStateTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testAbsPath(self):
- save_dir = _TestDir("abs_paths")
+ save_dir = self._get_test_dir("abs_paths")
abs_path = os.path.join(save_dir, "model-0")
ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
@@ -1297,7 +1307,7 @@ class CheckpointStateTest(test.TestCase):
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
def testAllModelCheckpointPaths(self):
- save_dir = _TestDir("all_models_test")
+ save_dir = self._get_test_dir("all_models_test")
abs_path = os.path.join(save_dir, "model-0")
for paths in [None, [], ["model-2"]]:
ckpt = saver_module.generate_checkpoint_state_proto(
@@ -1309,7 +1319,7 @@ class CheckpointStateTest(test.TestCase):
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
def testUpdateCheckpointState(self):
- save_dir = _TestDir("update_checkpoint_state")
+ save_dir = self._get_test_dir("update_checkpoint_state")
os.chdir(save_dir)
# Make a temporary train directory.
train_dir = "train"
@@ -1325,7 +1335,7 @@ class CheckpointStateTest(test.TestCase):
self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
def testCheckPointStateFailsWhenIncomplete(self):
- save_dir = _TestDir("checkpoint_state_fails_when_incomplete")
+ save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
os.chdir(save_dir)
ckpt_path = os.path.join(save_dir, "checkpoint")
ckpt_file = open(ckpt_path, "w")
@@ -1335,7 +1345,7 @@ class CheckpointStateTest(test.TestCase):
saver_module.get_checkpoint_state(save_dir)
def testCheckPointCompletesRelativePaths(self):
- save_dir = _TestDir("checkpoint_completes_relative_paths")
+ save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
os.chdir(save_dir)
ckpt_path = os.path.join(save_dir, "checkpoint")
ckpt_file = open(ckpt_path, "w")
@@ -1356,8 +1366,13 @@ class CheckpointStateTest(test.TestCase):
class MetaGraphTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testAddCollectionDef(self):
- test_dir = _TestDir("good_collection")
+ test_dir = self._get_test_dir("good_collection")
filename = os.path.join(test_dir, "metafile")
with self.test_session():
# Creates a graph.
@@ -1504,12 +1519,12 @@ class MetaGraphTest(test.TestCase):
self.assertEqual(11.0, v1.eval())
def testMultiSaverCollection(self):
- test_dir = _TestDir("saver_collection")
+ test_dir = self._get_test_dir("saver_collection")
self._testMultiSaverCollectionSave(test_dir)
self._testMultiSaverCollectionRestore(test_dir)
def testBinaryAndTextFormat(self):
- test_dir = _TestDir("binary_and_text")
+ test_dir = self._get_test_dir("binary_and_text")
filename = os.path.join(test_dir, "metafile")
with self.test_session(graph=ops_lib.Graph()):
# Creates a graph.
@@ -1541,7 +1556,7 @@ class MetaGraphTest(test.TestCase):
saver_module.import_meta_graph(filename)
def testSliceVariable(self):
- test_dir = _TestDir("slice_saver")
+ test_dir = self._get_test_dir("slice_saver")
filename = os.path.join(test_dir, "metafile")
with self.test_session():
v1 = variables.Variable([20.0], name="v1")
@@ -1679,7 +1694,7 @@ class MetaGraphTest(test.TestCase):
sess.run(train_op)
def testGraphExtension(self):
- test_dir = _TestDir("graph_extension")
+ test_dir = self._get_test_dir("graph_extension")
self._testGraphExtensionSave(test_dir)
self._testGraphExtensionRestore(test_dir)
self._testRestoreFromTrainGraphWithControlContext(test_dir)
@@ -1722,7 +1737,7 @@ class MetaGraphTest(test.TestCase):
def testImportIntoNamescope(self):
# Test that we can import a meta graph into a namescope.
- test_dir = _TestDir("import_into_namescope")
+ test_dir = self._get_test_dir("import_into_namescope")
filename = os.path.join(test_dir, "ckpt")
image = array_ops.placeholder(dtypes.float32, [None, 784])
label = array_ops.placeholder(dtypes.float32, [None, 10])
@@ -1870,8 +1885,13 @@ class CheckpointReaderForV2Test(CheckpointReaderTest):
class WriteGraphTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def testWriteGraph(self):
- test_dir = _TestDir("write_graph_dir")
+ test_dir = self._get_test_dir("write_graph_dir")
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph(),
os.path.join(test_dir, "l1"), "graph.pbtxt")
@@ -1881,7 +1901,7 @@ class WriteGraphTest(test.TestCase):
def testRecursiveCreate(self):
- test_dir = _TestDir("deep_dir")
+ test_dir = self._get_test_dir("deep_dir")
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
os.path.join(test_dir, "l1", "l2", "l3"),
@@ -1935,6 +1955,11 @@ class SaverUtilsTest(test.TestCase):
class ScopedGraphTest(test.TestCase):
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
graph = ops_lib.Graph()
with graph.as_default():
@@ -2067,7 +2092,7 @@ class ScopedGraphTest(test.TestCase):
# Verifies that we can save the subgraph under "hidden1" and restore it
# into "new_hidden1" in the new graph.
def testScopedSaveAndRestore(self):
- test_dir = _TestDir("scoped_export_import")
+ test_dir = self._get_test_dir("scoped_export_import")
ckpt_filename = "ckpt"
self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",
@@ -2076,7 +2101,7 @@ class ScopedGraphTest(test.TestCase):
# Verifies that we can copy the subgraph under "hidden1" and copy it
# to different name scope in the same graph or different graph.
def testCopyScopedGraph(self):
- test_dir = _TestDir("scoped_copy")
+ test_dir = self._get_test_dir("scoped_copy")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
graph1 = ops_lib.Graph()
with graph1.as_default():
@@ -2132,7 +2157,7 @@ class ScopedGraphTest(test.TestCase):
self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
def testExportGraphDefWithScope(self):
- test_dir = _TestDir("export_graph_def")
+ test_dir = self._get_test_dir("export_graph_def")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
graph1 = ops_lib.Graph()
with graph1.as_default():
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index 888f9f930a..4abce85852 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -64,15 +64,14 @@ def _summary_iterator(test_dir):
return summary_iterator.summary_iterator(event_paths[-1])
-def _test_dir(test_name):
- test_dir = os.path.join(test.get_temp_dir(), test_name)
- if os.path.exists(test_dir):
- shutil.rmtree(test_dir)
- return test_dir
-
-
class SupervisorTest(test.TestCase):
+ def _test_dir(self, test_name):
+ test_dir = os.path.join(self.get_temp_dir(), test_name)
+ if os.path.exists(test_dir):
+ shutil.rmtree(test_dir)
+ return test_dir
+
def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
"""Wait for a checkpoint file to appear.
@@ -94,7 +93,7 @@ class SupervisorTest(test.TestCase):
# This test does not test much.
def testBasics(self):
- logdir = _test_dir("basics")
+ logdir = self._test_dir("basics")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
@@ -105,7 +104,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testManagedSession(self):
- logdir = _test_dir("managed_session")
+ logdir = self._test_dir("managed_session")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
@@ -116,7 +115,7 @@ class SupervisorTest(test.TestCase):
self.assertTrue(sv.should_stop())
def testManagedSessionUserError(self):
- logdir = _test_dir("managed_user_error")
+ logdir = self._test_dir("managed_user_error")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
@@ -134,7 +133,7 @@ class SupervisorTest(test.TestCase):
self.assertEqual(1, last_step)
def testManagedSessionIgnoreOutOfRangeError(self):
- logdir = _test_dir("managed_out_of_range")
+ logdir = self._test_dir("managed_out_of_range")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
@@ -152,7 +151,7 @@ class SupervisorTest(test.TestCase):
self.assertEqual(3, last_step)
def testManagedSessionDoNotKeepSummaryWriter(self):
- logdir = _test_dir("managed_not_keep_summary_writer")
+ logdir = self._test_dir("managed_not_keep_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
@@ -204,7 +203,7 @@ class SupervisorTest(test.TestCase):
next(rr)
def testManagedSessionKeepSummaryWriter(self):
- logdir = _test_dir("managed_keep_summary_writer")
+ logdir = self._test_dir("managed_keep_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
@@ -266,7 +265,7 @@ class SupervisorTest(test.TestCase):
def testManagedEndOfInputOneQueue(self):
# Tests that the supervisor finishes without an error when using
# a fixed number of epochs, reading from a single queue.
- logdir = _test_dir("managed_end_of_input_one_queue")
+ logdir = self._test_dir("managed_end_of_input_one_queue")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with ops.Graph().as_default():
@@ -285,7 +284,7 @@ class SupervisorTest(test.TestCase):
# Tests that the supervisor finishes without an error when using
# a fixed number of epochs, reading from two queues, the second
# one producing a batch from the first one.
- logdir = _test_dir("managed_end_of_input_two_queues")
+ logdir = self._test_dir("managed_end_of_input_two_queues")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with ops.Graph().as_default():
@@ -304,7 +303,7 @@ class SupervisorTest(test.TestCase):
def testManagedMainErrorTwoQueues(self):
# Tests that the supervisor correctly raises a main loop
# error even when using multiple queues for input.
- logdir = _test_dir("managed_main_error_two_queues")
+ logdir = self._test_dir("managed_main_error_two_queues")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):
@@ -327,7 +326,7 @@ class SupervisorTest(test.TestCase):
sess.run(shuff_rec)
def testSessionConfig(self):
- logdir = _test_dir("session_config")
+ logdir = self._test_dir("session_config")
with ops.Graph().as_default():
with ops.device("/cpu:1"):
my_op = constant_op.constant([1.0])
@@ -340,7 +339,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testChiefCanWriteEvents(self):
- logdir = _test_dir("can_write")
+ logdir = self._test_dir("can_write")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
@@ -421,7 +420,7 @@ class SupervisorTest(test.TestCase):
sv.summary_computed(sess, sess.run(summ))
def testLogdirButExplicitlyNoSummaryWriter(self):
- logdir = _test_dir("explicit_no_summary_writer")
+ logdir = self._test_dir("explicit_no_summary_writer")
with ops.Graph().as_default():
variables.Variable([1.0], name="foo")
summary.scalar("c1", constant_op.constant(1))
@@ -437,7 +436,7 @@ class SupervisorTest(test.TestCase):
sv.summary_computed(sess, sess.run(summ))
def testNoLogdirButExplicitSummaryWriter(self):
- logdir = _test_dir("explicit_summary_writer")
+ logdir = self._test_dir("explicit_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
@@ -506,7 +505,7 @@ class SupervisorTest(test.TestCase):
sv.prepare_or_wait_for_session("")
def testInitOp(self):
- logdir = _test_dir("default_init_op")
+ logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
v = variables.Variable([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir=logdir)
@@ -515,7 +514,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testInitFn(self):
- logdir = _test_dir("default_init_op")
+ logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
v = variables.Variable([1.0, 2.0, 3.0])
@@ -528,7 +527,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testInitOpWithFeedDict(self):
- logdir = _test_dir("feed_dict_init_op")
+ logdir = self._test_dir("feed_dict_init_op")
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
v = variables.Variable(p, name="v")
@@ -542,7 +541,7 @@ class SupervisorTest(test.TestCase):
def testReadyForLocalInitOp(self):
server = server_lib.Server.create_local_server()
- logdir = _test_dir("default_ready_for_local_init_op")
+ logdir = self._test_dir("default_ready_for_local_init_op")
uid = uuid.uuid4().hex
@@ -584,7 +583,7 @@ class SupervisorTest(test.TestCase):
def testReadyForLocalInitOpRestoreFromCheckpoint(self):
server = server_lib.Server.create_local_server()
- logdir = _test_dir("ready_for_local_init_op_restore")
+ logdir = self._test_dir("ready_for_local_init_op_restore")
uid = uuid.uuid4().hex
@@ -639,7 +638,7 @@ class SupervisorTest(test.TestCase):
sv1.stop()
def testLocalInitOp(self):
- logdir = _test_dir("default_local_init_op")
+ logdir = self._test_dir("default_local_init_op")
with ops.Graph().as_default():
# A local variable.
v = variables.Variable(
@@ -664,7 +663,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testLocalInitOpForNonChief(self):
- logdir = _test_dir("default_local_init_op_non_chief")
+ logdir = self._test_dir("default_local_init_op_non_chief")
with ops.Graph().as_default():
with ops.device("/job:localhost"):
# A local variable.
@@ -685,7 +684,7 @@ class SupervisorTest(test.TestCase):
def testInitOpFails(self):
server = server_lib.Server.create_local_server()
- logdir = _test_dir("default_init_op_fails")
+ logdir = self._test_dir("default_init_op_fails")
with ops.Graph().as_default():
v = variables.Variable([1.0, 2.0, 3.0], name="v")
variables.Variable([4.0, 5.0, 6.0], name="w")
@@ -697,7 +696,7 @@ class SupervisorTest(test.TestCase):
def testInitOpFailsForTransientVariable(self):
server = server_lib.Server.create_local_server()
- logdir = _test_dir("default_init_op_fails_for_local_variable")
+ logdir = self._test_dir("default_init_op_fails_for_local_variable")
with ops.Graph().as_default():
v = variables.Variable(
[1.0, 2.0, 3.0],
@@ -714,7 +713,7 @@ class SupervisorTest(test.TestCase):
sv.prepare_or_wait_for_session(server.target)
def testSetupFail(self):
- logdir = _test_dir("setup_fail")
+ logdir = self._test_dir("setup_fail")
with ops.Graph().as_default():
variables.Variable([1.0, 2.0, 3.0], name="v")
with self.assertRaisesRegexp(ValueError, "must have their device set"):
@@ -724,7 +723,7 @@ class SupervisorTest(test.TestCase):
supervisor.Supervisor(logdir=logdir, is_chief=False)
def testDefaultGlobalStep(self):
- logdir = _test_dir("default_global_step")
+ logdir = self._test_dir("default_global_step")
with ops.Graph().as_default():
variables.Variable(287, name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
@@ -733,7 +732,7 @@ class SupervisorTest(test.TestCase):
sv.stop()
def testRestoreFromMetaGraph(self):
- logdir = _test_dir("restore_from_meta_graph")
+ logdir = self._test_dir("restore_from_meta_graph")
with ops.Graph().as_default():
variables.Variable(1, name="v0")
sv = supervisor.Supervisor(logdir=logdir)
@@ -754,7 +753,7 @@ class SupervisorTest(test.TestCase):
# right away and get to run once before sv.stop() returns.
# We still sleep a bit to make the test robust.
def testStandardServicesWithoutGlobalStep(self):
- logdir = _test_dir("standard_services_without_global_step")
+ logdir = self._test_dir("standard_services_without_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
v = variables.Variable([1.0], name="foo")
@@ -804,7 +803,7 @@ class SupervisorTest(test.TestCase):
# Same as testStandardServicesNoGlobalStep but with a global step.
# We should get a summary about the step time.
def testStandardServicesWithGlobalStep(self):
- logdir = _test_dir("standard_services_with_global_step")
+ logdir = self._test_dir("standard_services_with_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
v = variables.Variable([123], name="global_step")
@@ -867,12 +866,12 @@ class SupervisorTest(test.TestCase):
def testNoQueueRunners(self):
with ops.Graph().as_default(), self.test_session() as sess:
- sv = supervisor.Supervisor(logdir=_test_dir("no_queue_runners"))
+ sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
self.assertEqual(0, len(sv.start_queue_runners(sess)))
sv.stop()
def testPrepareSessionAfterStopForChief(self):
- logdir = _test_dir("prepare_after_stop_chief")
+ logdir = self._test_dir("prepare_after_stop_chief")
with ops.Graph().as_default():
sv = supervisor.Supervisor(logdir=logdir, is_chief=True)
@@ -891,7 +890,7 @@ class SupervisorTest(test.TestCase):
self.assertTrue(sv.should_stop())
def testPrepareSessionAfterStopForNonChief(self):
- logdir = _test_dir("prepare_after_stop_nonchief")
+ logdir = self._test_dir("prepare_after_stop_nonchief")
with ops.Graph().as_default():
sv = supervisor.Supervisor(logdir=logdir, is_chief=False)