# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for supervisor.py.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import glob import os import shutil import time import uuid from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer from tensorflow.python.training import checkpoint_management from tensorflow.python.training import input as input_lib from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import server_lib from tensorflow.python.training import session_manager as session_manager_lib from tensorflow.python.training import supervisor def _summary_iterator(test_dir): """Reads events from test_dir/events. Args: test_dir: Name of the test directory. Returns: A summary_iterator """ event_paths = sorted(glob.glob(os.path.join(test_dir, "event*"))) return summary_iterator.summary_iterator(event_paths[-1]) class SupervisorTest(test.TestCase): def _test_dir(self, test_name): test_dir = os.path.join(self.get_temp_dir(), test_name) if os.path.exists(test_dir): shutil.rmtree(test_dir) return test_dir def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True): """Wait for a checkpoint file to appear. Args: pattern: A string. timeout_secs: How long to wait for in seconds. for_checkpoint: whether we're globbing for checkpoints. """ end_time = time.time() + timeout_secs while time.time() < end_time: if for_checkpoint: if checkpoint_management.checkpoint_exists(pattern): return else: if len(gfile.Glob(pattern)) >= 1: return time.sleep(0.05) self.assertFalse(True, "Glob never matched any file: %s" % pattern) # This test does not test much. def testBasics(self): logdir = self._test_dir("basics") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) sess = sv.prepare_or_wait_for_session("") for _ in xrange(10): sess.run(my_op) sess.close() sv.stop() def testManagedSession(self): logdir = self._test_dir("managed_session") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) with sv.managed_session("") as sess: for _ in xrange(10): sess.run(my_op) # Supervisor has been stopped. self.assertTrue(sv.should_stop()) def testManagedSessionUserError(self): logdir = self._test_dir("managed_user_error") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) last_step = None with self.assertRaisesRegexp(RuntimeError, "failing here"): with sv.managed_session("") as sess: for step in xrange(10): last_step = step if step == 1: raise RuntimeError("failing here") else: sess.run(my_op) # Supervisor has been stopped. self.assertTrue(sv.should_stop()) self.assertEqual(1, last_step) def testManagedSessionIgnoreOutOfRangeError(self): logdir = self._test_dir("managed_out_of_range") with ops.Graph().as_default(): my_op = constant_op.constant(1.0) sv = supervisor.Supervisor(logdir=logdir) last_step = None with sv.managed_session("") as sess: for step in xrange(10): last_step = step if step == 3: raise errors_impl.OutOfRangeError(my_op.op.node_def, my_op.op, "all done") else: sess.run(my_op) # Supervisor has been stopped. OutOfRangeError was not thrown. self.assertTrue(sv.should_stop()) self.assertEqual(3, last_step) def testManagedSessionDoNotKeepSummaryWriter(self): logdir = self._test_dir("managed_not_keep_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) summary.scalar("c3", constant_op.constant(3)) summ = summary.merge_all() sv = supervisor.Supervisor(logdir=logdir, summary_op=None) with sv.managed_session( "", close_summary_writer=True, start_standard_services=False) as sess: sv.summary_computed(sess, sess.run(summ)) # Sleep 1.2s to make sure that the next event file has a different name # than the current one. time.sleep(1.2) with sv.managed_session( "", close_summary_writer=True, start_standard_services=False) as sess: sv.summary_computed(sess, sess.run(summ)) event_paths = sorted(glob.glob(os.path.join(logdir, "event*"))) self.assertEquals(2, len(event_paths)) # The two event files should have the same contents. for path in event_paths: # The summary iterator should report the summary once as we closed the # summary writer across the 2 sessions. rr = summary_iterator.summary_iterator(path) # The first event should list the file_version. ev = next(rr) self.assertEquals("brain.Event:2", ev.file_version) # The next one has the graph and metagraph. ev = next(rr) self.assertTrue(ev.graph_def) ev = next(rr) self.assertTrue(ev.meta_graph_def) # The next one should have the values from the summary. # But only once. ev = next(rr) self.assertProtoEquals(""" value { tag: 'c1' simple_value: 1.0 } value { tag: 'c2' simple_value: 2.0 } value { tag: 'c3' simple_value: 3.0 } """, ev.summary) # The next one should be a stop message if we closed cleanly. ev = next(rr) self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status) # We should be done. with self.assertRaises(StopIteration): next(rr) def testManagedSessionKeepSummaryWriter(self): logdir = self._test_dir("managed_keep_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) summary.scalar("c3", constant_op.constant(3)) summ = summary.merge_all() sv = supervisor.Supervisor(logdir=logdir) with sv.managed_session( "", close_summary_writer=False, start_standard_services=False) as sess: sv.summary_computed(sess, sess.run(summ)) with sv.managed_session( "", close_summary_writer=False, start_standard_services=False) as sess: sv.summary_computed(sess, sess.run(summ)) # Now close the summary writer to flush the events. sv.summary_writer.close() # The summary iterator should report the summary twice as we reused # the same summary writer across the 2 sessions. rr = _summary_iterator(logdir) # The first event should list the file_version. ev = next(rr) self.assertEquals("brain.Event:2", ev.file_version) # The next one has the graph. ev = next(rr) self.assertTrue(ev.graph_def) ev = next(rr) self.assertTrue(ev.meta_graph_def) # The next one should have the values from the summary. ev = next(rr) self.assertProtoEquals(""" value { tag: 'c1' simple_value: 1.0 } value { tag: 'c2' simple_value: 2.0 } value { tag: 'c3' simple_value: 3.0 } """, ev.summary) # The next one should also have the values from the summary. ev = next(rr) self.assertProtoEquals(""" value { tag: 'c1' simple_value: 1.0 } value { tag: 'c2' simple_value: 2.0 } value { tag: 'c3' simple_value: 3.0 } """, ev.summary) # We should be done. self.assertRaises(StopIteration, lambda: next(rr)) def _csv_data(self, logdir): # Create a small data file with 3 CSV records. data_path = os.path.join(logdir, "data.csv") with open(data_path, "w") as f: f.write("1,2,3\n") f.write("4,5,6\n") f.write("7,8,9\n") return data_path def testManagedEndOfInputOneQueue(self): # Tests that the supervisor finishes without an error when using # a fixed number of epochs, reading from a single queue. logdir = self._test_dir("managed_end_of_input_one_queue") os.makedirs(logdir) data_path = self._csv_data(logdir) with ops.Graph().as_default(): # Create an input pipeline that reads the file 3 times. filename_queue = input_lib.string_input_producer( [data_path], num_epochs=3) reader = io_ops.TextLineReader() _, csv = reader.read(filename_queue) rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]]) sv = supervisor.Supervisor(logdir=logdir) with sv.managed_session("") as sess: while not sv.should_stop(): sess.run(rec) def testManagedEndOfInputTwoQueues(self): # Tests that the supervisor finishes without an error when using # a fixed number of epochs, reading from two queues, the second # one producing a batch from the first one. logdir = self._test_dir("managed_end_of_input_two_queues") os.makedirs(logdir) data_path = self._csv_data(logdir) with ops.Graph().as_default(): # Create an input pipeline that reads the file 3 times. filename_queue = input_lib.string_input_producer( [data_path], num_epochs=3) reader = io_ops.TextLineReader() _, csv = reader.read(filename_queue) rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]]) shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4) sv = supervisor.Supervisor(logdir=logdir) with sv.managed_session("") as sess: while not sv.should_stop(): sess.run(shuff_rec) def testManagedMainErrorTwoQueues(self): # Tests that the supervisor correctly raises a main loop # error even when using multiple queues for input. logdir = self._test_dir("managed_main_error_two_queues") os.makedirs(logdir) data_path = self._csv_data(logdir) with self.assertRaisesRegexp(RuntimeError, "fail at step 3"): with ops.Graph().as_default(): # Create an input pipeline that reads the file 3 times. filename_queue = input_lib.string_input_producer( [data_path], num_epochs=3) reader = io_ops.TextLineReader() _, csv = reader.read(filename_queue) rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]]) shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4) sv = supervisor.Supervisor(logdir=logdir) with sv.managed_session("") as sess: for step in range(9): if sv.should_stop(): break elif step == 3: raise RuntimeError("fail at step 3") else: sess.run(shuff_rec) def testSessionConfig(self): logdir = self._test_dir("session_config") with ops.Graph().as_default(): with ops.device("/cpu:1"): my_op = constant_op.constant([1.0]) sv = supervisor.Supervisor(logdir=logdir) sess = sv.prepare_or_wait_for_session( "", config=config_pb2.ConfigProto(device_count={"CPU": 2})) for _ in xrange(10): sess.run(my_op) sess.close() sv.stop() def testChiefCanWriteEvents(self): logdir = self._test_dir("can_write") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) summary.scalar("c3", constant_op.constant(3)) summ = summary.merge_all() sv = supervisor.Supervisor(is_chief=True, logdir=logdir, summary_op=None) meta_graph_def = meta_graph.create_meta_graph_def() sess = sv.prepare_or_wait_for_session("") sv.summary_computed(sess, sess.run(summ)) sess.close() # Wait to make sure everything is written to file before stopping. time.sleep(1) sv.stop() rr = _summary_iterator(logdir) # The first event should list the file_version. ev = next(rr) self.assertEquals("brain.Event:2", ev.file_version) # The next one has the graph. ev = next(rr) ev_graph = graph_pb2.GraphDef() ev_graph.ParseFromString(ev.graph_def) self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph) # Stored MetaGraphDef ev = next(rr) ev_meta_graph = meta_graph_pb2.MetaGraphDef() ev_meta_graph.ParseFromString(ev.meta_graph_def) self.assertProtoEquals(meta_graph_def, ev_meta_graph) self.assertProtoEquals( sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def) # The next one should have the values from the summary. ev = next(rr) self.assertProtoEquals(""" value { tag: 'c1' simple_value: 1.0 } value { tag: 'c2' simple_value: 2.0 } value { tag: 'c3' simple_value: 3.0 } """, ev.summary) # The next one should be a stop message if we closed cleanly. ev = next(rr) self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status) # We should be done. self.assertRaises(StopIteration, lambda: next(rr)) def testNonChiefCannotWriteEvents(self): def _summary_computed(): with ops.Graph().as_default(): sv = supervisor.Supervisor(is_chief=False) sess = sv.prepare_or_wait_for_session("") summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) summ = summary.merge_all() sv.summary_computed(sess, sess.run(summ)) def _start_standard_services(): with ops.Graph().as_default(): sv = supervisor.Supervisor(is_chief=False) sess = sv.prepare_or_wait_for_session("") sv.start_standard_services(sess) self.assertRaises(RuntimeError, _summary_computed) self.assertRaises(RuntimeError, _start_standard_services) def testNoLogdirButWantSummary(self): with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) summary.scalar("c3", constant_op.constant(3)) summ = summary.merge_all() sv = supervisor.Supervisor(logdir="", summary_op=None) sess = sv.prepare_or_wait_for_session("") with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"): sv.summary_computed(sess, sess.run(summ)) def testLogdirButExplicitlyNoSummaryWriter(self): logdir = self._test_dir("explicit_no_summary_writer") with ops.Graph().as_default(): variables.VariableV1([1.0], name="foo") summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) summary.scalar("c3", constant_op.constant(3)) summ = summary.merge_all() sv = supervisor.Supervisor(logdir=logdir, summary_writer=None) sess = sv.prepare_or_wait_for_session("") # Check that a checkpoint is still be generated. self._wait_for_glob(sv.save_path, 3.0) # Check that we cannot write a summary with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"): sv.summary_computed(sess, sess.run(summ)) def testNoLogdirButExplicitSummaryWriter(self): logdir = self._test_dir("explicit_summary_writer") with ops.Graph().as_default(): summary.scalar("c1", constant_op.constant(1)) summary.scalar("c2", constant_op.constant(2)) summary.scalar("c3", constant_op.constant(3)) summ = summary.merge_all() sw = writer.FileWriter(logdir) sv = supervisor.Supervisor(logdir="", summary_op=None, summary_writer=sw) meta_graph_def = meta_graph.create_meta_graph_def() sess = sv.prepare_or_wait_for_session("") sv.summary_computed(sess, sess.run(summ)) sess.close() # Wait to make sure everything is written to file before stopping. time.sleep(1) sv.stop() # Check the summary was written to 'logdir' rr = _summary_iterator(logdir) # The first event should list the file_version. ev = next(rr) self.assertEquals("brain.Event:2", ev.file_version) # The next one has the graph. ev = next(rr) ev_graph = graph_pb2.GraphDef() ev_graph.ParseFromString(ev.graph_def) self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph) # Stored MetaGraphDef ev = next(rr) ev_meta_graph = meta_graph_pb2.MetaGraphDef() ev_meta_graph.ParseFromString(ev.meta_graph_def) self.assertProtoEquals(meta_graph_def, ev_meta_graph) self.assertProtoEquals( sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def) # The next one should have the values from the summary. ev = next(rr) self.assertProtoEquals(""" value { tag: 'c1' simple_value: 1.0 } value { tag: 'c2' simple_value: 2.0 } value { tag: 'c3' simple_value: 3.0 } """, ev.summary) # The next one should be a stop message if we closed cleanly. ev = next(rr) self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status) # We should be done. self.assertRaises(StopIteration, lambda: next(rr)) def testNoLogdirSucceeds(self): with ops.Graph().as_default(): variables.VariableV1([1.0, 2.0, 3.0]) sv = supervisor.Supervisor(logdir="", summary_op=None) sess = sv.prepare_or_wait_for_session("") sess.close() sv.stop() def testUseSessionManager(self): with ops.Graph().as_default(): variables.VariableV1([1.0, 2.0, 3.0]) sm = session_manager_lib.SessionManager() # Pass in session_manager. The additional init_op is ignored. sv = supervisor.Supervisor(logdir="", session_manager=sm) sv.prepare_or_wait_for_session("") def testInitOp(self): logdir = self._test_dir("default_init_op") with ops.Graph().as_default(): v = variables.VariableV1([1.0, 2.0, 3.0]) sv = supervisor.Supervisor(logdir=logdir) sess = sv.prepare_or_wait_for_session("") self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) sv.stop() def testInitFn(self): logdir = self._test_dir("default_init_op") with ops.Graph().as_default(): v = variables.VariableV1([1.0, 2.0, 3.0]) def _init_fn(sess): sess.run(v.initializer) sv = supervisor.Supervisor(logdir=logdir, init_op=None, init_fn=_init_fn) sess = sv.prepare_or_wait_for_session("") self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) sv.stop() def testInitOpWithFeedDict(self): logdir = self._test_dir("feed_dict_init_op") with ops.Graph().as_default(): p = array_ops.placeholder(dtypes.float32, shape=(3,)) v = variables.VariableV1(p, name="v") sv = supervisor.Supervisor( logdir=logdir, init_op=variables.global_variables_initializer(), init_feed_dict={p: [1.0, 2.0, 3.0]}) sess = sv.prepare_or_wait_for_session("") self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) sv.stop() def testReadyForLocalInitOp(self): server = server_lib.Server.create_local_server() logdir = self._test_dir("default_ready_for_local_init_op") uid = uuid.uuid4().hex def get_session(is_chief): g = ops.Graph() with g.as_default(): with ops.device("/job:local"): v = variables.VariableV1( 1, name="default_ready_for_local_init_op_v_" + str(uid)) vadd = v.assign_add(1) w = variables.VariableV1( v, trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="default_ready_for_local_init_op_w_" + str(uid)) ready_for_local_init_op = variables.report_uninitialized_variables( variables.global_variables()) sv = supervisor.Supervisor( logdir=logdir, is_chief=is_chief, graph=g, recovery_wait_secs=1, init_op=v.initializer, ready_for_local_init_op=ready_for_local_init_op) sess = sv.prepare_or_wait_for_session(server.target) return sv, sess, v, vadd, w sv0, sess0, v0, _, w0 = get_session(True) sv1, sess1, _, vadd1, w1 = get_session(False) self.assertEqual(1, sess0.run(w0)) self.assertEqual(2, sess1.run(vadd1)) self.assertEqual(1, sess1.run(w1)) self.assertEqual(2, sess0.run(v0)) sv0.stop() sv1.stop() def testReadyForLocalInitOpRestoreFromCheckpoint(self): server = server_lib.Server.create_local_server() logdir = self._test_dir("ready_for_local_init_op_restore") uid = uuid.uuid4().hex # Create a checkpoint. with ops.Graph().as_default(): v = variables.VariableV1( 10.0, name="ready_for_local_init_op_restore_v_" + str(uid)) summary.scalar("ready_for_local_init_op_restore_v_" + str(uid), v) sv = supervisor.Supervisor(logdir=logdir) sv.prepare_or_wait_for_session(server.target) save_path = sv.save_path self._wait_for_glob(save_path, 3.0) self._wait_for_glob( os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False) # Wait to make sure everything is written to file before stopping. time.sleep(1) sv.stop() def get_session(is_chief): g = ops.Graph() with g.as_default(): with ops.device("/job:local"): v = variables.VariableV1( 1.0, name="ready_for_local_init_op_restore_v_" + str(uid)) vadd = v.assign_add(1) w = variables.VariableV1( v, trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], name="ready_for_local_init_op_restore_w_" + str(uid)) ready_for_local_init_op = variables.report_uninitialized_variables( variables.global_variables()) sv = supervisor.Supervisor( logdir=logdir, is_chief=is_chief, graph=g, recovery_wait_secs=1, ready_for_local_init_op=ready_for_local_init_op) sess = sv.prepare_or_wait_for_session(server.target) return sv, sess, v, vadd, w sv0, sess0, v0, _, w0 = get_session(True) sv1, sess1, _, vadd1, w1 = get_session(False) self.assertEqual(10, sess0.run(w0)) self.assertEqual(11, sess1.run(vadd1)) self.assertEqual(10, sess1.run(w1)) self.assertEqual(11, sess0.run(v0)) sv0.stop() sv1.stop() def testLocalInitOp(self): logdir = self._test_dir("default_local_init_op") with ops.Graph().as_default(): # A local variable. v = variables.VariableV1( [1.0, 2.0, 3.0], trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES]) # An entity which is initialized through a TABLE_INITIALIZER. w = variables.VariableV1([4, 5, 6], trainable=False, collections=[]) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer) # This shouldn't add a variable to the VARIABLES collection responsible # for variables that are saved/restored from checkpoints. self.assertEquals(len(variables.global_variables()), 0) # Suppress normal variable inits to make sure the local one is # initialized via local_init_op. sv = supervisor.Supervisor(logdir=logdir, init_op=None) sess = sv.prepare_or_wait_for_session("") self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) self.assertAllClose([4, 5, 6], sess.run(w)) sv.stop() def testLocalInitOpForNonChief(self): logdir = self._test_dir("default_local_init_op_non_chief") with ops.Graph().as_default(): with ops.device("/job:localhost"): # A local variable. v = variables.VariableV1( [1.0, 2.0, 3.0], trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES]) # This shouldn't add a variable to the VARIABLES collection responsible # for variables that are saved/restored from checkpoints. self.assertEquals(len(variables.global_variables()), 0) # Suppress normal variable inits to make sure the local one is # initialized via local_init_op. sv = supervisor.Supervisor(logdir=logdir, init_op=None, is_chief=False) sess = sv.prepare_or_wait_for_session("") self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) sv.stop() def testInitOpFails(self): server = server_lib.Server.create_local_server() logdir = self._test_dir("default_init_op_fails") with ops.Graph().as_default(): v = variables.VariableV1([1.0, 2.0, 3.0], name="v") variables.VariableV1([4.0, 5.0, 6.0], name="w") # w will not be initialized. sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer) with self.assertRaisesRegexp(RuntimeError, "Variables not initialized: w"): sv.prepare_or_wait_for_session(server.target) def testInitOpFailsForTransientVariable(self): server = server_lib.Server.create_local_server() logdir = self._test_dir("default_init_op_fails_for_local_variable") with ops.Graph().as_default(): v = variables.VariableV1( [1.0, 2.0, 3.0], name="v", collections=[ops.GraphKeys.LOCAL_VARIABLES]) variables.VariableV1( [1.0, 2.0, 3.0], name="w", collections=[ops.GraphKeys.LOCAL_VARIABLES]) # w will not be initialized. sv = supervisor.Supervisor(logdir=logdir, local_init_op=v.initializer) with self.assertRaisesRegexp(RuntimeError, "Variables not initialized: w"): sv.prepare_or_wait_for_session(server.target) def testSetupFail(self): logdir = self._test_dir("setup_fail") with ops.Graph().as_default(): variables.VariableV1([1.0, 2.0, 3.0], name="v") with self.assertRaisesRegexp(ValueError, "must have their device set"): supervisor.Supervisor(logdir=logdir, is_chief=False) with ops.Graph().as_default(), ops.device("/job:ps"): variables.VariableV1([1.0, 2.0, 3.0], name="v") supervisor.Supervisor(logdir=logdir, is_chief=False) def testDefaultGlobalStep(self): logdir = self._test_dir("default_global_step") with ops.Graph().as_default(): variables.VariableV1(287, name="global_step") sv = supervisor.Supervisor(logdir=logdir) sess = sv.prepare_or_wait_for_session("") self.assertEquals(287, sess.run(sv.global_step)) sv.stop() def testRestoreFromMetaGraph(self): logdir = self._test_dir("restore_from_meta_graph") with ops.Graph().as_default(): variables.VariableV1(1, name="v0") sv = supervisor.Supervisor(logdir=logdir) sess = sv.prepare_or_wait_for_session("") filename = sv.saver.save(sess, sv.save_path) sv.stop() # Create a new Graph and Supervisor and recover. with ops.Graph().as_default(): new_saver = saver_lib.import_meta_graph(".".join([filename, "meta"])) self.assertIsNotNone(new_saver) sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver) sess = sv2.prepare_or_wait_for_session("") self.assertEquals(1, sess.run("v0:0")) sv2.saver.save(sess, sv2.save_path) sv2.stop() # This test is based on the fact that the standard services start # right away and get to run once before sv.stop() returns. # We still sleep a bit to make the test robust. def testStandardServicesWithoutGlobalStep(self): logdir = self._test_dir("standard_services_without_global_step") # Create a checkpoint. with ops.Graph().as_default(): v = variables.VariableV1([1.0], name="foo") summary.scalar("v", v[0]) sv = supervisor.Supervisor(logdir=logdir) meta_graph_def = meta_graph.create_meta_graph_def( saver_def=sv.saver.saver_def) sess = sv.prepare_or_wait_for_session("") save_path = sv.save_path self._wait_for_glob(save_path, 3.0) self._wait_for_glob( os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False) # Wait to make sure everything is written to file before stopping. time.sleep(1) sv.stop() # There should be an event file with a version number. rr = _summary_iterator(logdir) ev = next(rr) self.assertEquals("brain.Event:2", ev.file_version) ev = next(rr) ev_graph = graph_pb2.GraphDef() ev_graph.ParseFromString(ev.graph_def) self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph) # Stored MetaGraphDef ev = next(rr) ev_meta_graph = meta_graph_pb2.MetaGraphDef() ev_meta_graph.ParseFromString(ev.meta_graph_def) self.assertProtoEquals(meta_graph_def, ev_meta_graph) self.assertProtoEquals( sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def) ev = next(rr) self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary) ev = next(rr) self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status) self.assertRaises(StopIteration, lambda: next(rr)) # There should be a checkpoint file with the variable "foo" with ops.Graph().as_default(), self.cached_session() as sess: v = variables.VariableV1([10.10], name="foo") sav = saver_lib.Saver([v]) sav.restore(sess, save_path) self.assertEqual(1.0, v.eval()[0]) # Same as testStandardServicesNoGlobalStep but with a global step. # We should get a summary about the step time. def testStandardServicesWithGlobalStep(self): logdir = self._test_dir("standard_services_with_global_step") # Create a checkpoint. with ops.Graph().as_default(): v = variables.VariableV1([123], name="global_step") sv = supervisor.Supervisor(logdir=logdir) meta_graph_def = meta_graph.create_meta_graph_def( saver_def=sv.saver.saver_def) sess = sv.prepare_or_wait_for_session("") # This is where the checkpoint will appear, with step number 123. save_path = "%s-123" % sv.save_path self._wait_for_glob(save_path, 3.0) self._wait_for_glob( os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False) # Wait to make sure everything is written to file before stopping. time.sleep(1) sv.stop() # There should be an event file with a version number. rr = _summary_iterator(logdir) ev = next(rr) self.assertEquals("brain.Event:2", ev.file_version) ev = next(rr) ev_graph = graph_pb2.GraphDef() ev_graph.ParseFromString(ev.graph_def) self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph) ev = next(rr) ev_meta_graph = meta_graph_pb2.MetaGraphDef() ev_meta_graph.ParseFromString(ev.meta_graph_def) self.assertProtoEquals(meta_graph_def, ev_meta_graph) self.assertProtoEquals( sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def) ev = next(rr) # It is actually undeterministic whether SessionLog.START gets written # before the summary or the checkpoint, but this works when run 10000 times. self.assertEquals(123, ev.step) self.assertEquals(event_pb2.SessionLog.START, ev.session_log.status) first = next(rr) second = next(rr) # It is undeterministic whether the value gets written before the checkpoint # since they are on separate threads, so we check for both conditions. if first.HasField("summary"): self.assertProtoEquals("""value { tag: 'global_step/sec' simple_value: 0.0 }""", first.summary) self.assertEquals(123, second.step) self.assertEquals(event_pb2.SessionLog.CHECKPOINT, second.session_log.status) else: self.assertEquals(123, first.step) self.assertEquals(event_pb2.SessionLog.CHECKPOINT, first.session_log.status) self.assertProtoEquals("""value { tag: 'global_step/sec' simple_value: 0.0 }""", second.summary) ev = next(rr) self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status) self.assertRaises(StopIteration, lambda: next(rr)) # There should be a checkpoint file with the variable "foo" with ops.Graph().as_default(), self.cached_session() as sess: v = variables.VariableV1([-12], name="global_step") sav = saver_lib.Saver([v]) sav.restore(sess, save_path) self.assertEqual(123, v.eval()[0]) def testNoQueueRunners(self): with ops.Graph().as_default(), self.cached_session() as sess: sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners")) self.assertEqual(0, len(sv.start_queue_runners(sess))) sv.stop() def testPrepareSessionAfterStopForChief(self): logdir = self._test_dir("prepare_after_stop_chief") with ops.Graph().as_default(): sv = supervisor.Supervisor(logdir=logdir, is_chief=True) # Create a first session and then stop. sess = sv.prepare_or_wait_for_session("") sv.stop() sess.close() self.assertTrue(sv.should_stop()) # Now create a second session and test that we don't stay stopped, until # we ask to stop again. sess2 = sv.prepare_or_wait_for_session("") self.assertFalse(sv.should_stop()) sv.stop() sess2.close() self.assertTrue(sv.should_stop()) def testPrepareSessionAfterStopForNonChief(self): logdir = self._test_dir("prepare_after_stop_nonchief") with ops.Graph().as_default(): sv = supervisor.Supervisor(logdir=logdir, is_chief=False) # Create a first session and then stop. sess = sv.prepare_or_wait_for_session("") sv.stop() sess.close() self.assertTrue(sv.should_stop()) # Now create a second session and test that we don't stay stopped, until # we ask to stop again. sess2 = sv.prepare_or_wait_for_session("") self.assertFalse(sv.should_stop()) sv.stop() sess2.close() self.assertTrue(sv.should_stop()) if __name__ == "__main__": test.main()