| author | Jeremy Lau <lauj@google.com> | 2018-07-13 14:29:12 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-13 14:33:08 -0700 |
| commit | 590af170ca85a4921db0c28e4fa2785462bdcebd | |
| tree | 56ef59605e42a992dbe88c977da0f25e728f0020 | |
| parent | 3d949e2016a967e303b8ddbd3cbe3bd3408320e8 | |
TPUEstimator: Run tpu.initialize_system() in its own graph whenever the main
graph is finalized.
PiperOrigin-RevId: 204529164
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_context.py   |  10
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 101
2 files changed, 91 insertions(+), 20 deletions(-)
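The core of the change: instead of unconditionally putting `tpu.initialize_system()` into the main graph's init ops, the estimator now wraps `Scaffold.finalize` so that, for distributed masters, the TPU is (re)initialized in a separate graph and session every time the main graph is finalized. The sketch below is a simplified illustration of that pattern, not the TPUEstimator code itself; it assumes TensorFlow 1.x with `tf.contrib.tpu`, and `master_address`/`master_job` are placeholder parameters standing in for values that `_InternalTPUContext` normally supplies.

```python
# Minimal sketch of wrapping Scaffold.finalize so tpu.initialize_system()
# runs in its own graph whenever the main graph is finalized.
# Assumes TensorFlow 1.x with tf.contrib.tpu; `master_address` and
# `master_job` are hypothetical stand-ins for the estimator context's values.
import tensorflow as tf


def wrap_scaffold_with_tpu_init(scaffold, master_address, master_job=None):
  """Returns `scaffold` with finalize() wrapped to (re)initialize the TPU."""
  if scaffold is None:
    scaffold = tf.train.Scaffold()
  wrapped_finalize = scaffold.finalize

  def _finalize():
    wrapped_finalize()  # finalize the main graph first
    if master_address is None:
      return  # local DirectSession: init stays in the main graph's init_ops
    # Build and run initialize_system in a separate graph so the main graph's
    # replicate ops are not rewritten before the TPU is initialized.
    with tf.Graph().as_default():
      init_op = tf.contrib.tpu.initialize_system(job=master_job)
      with tf.Session(master_address) as init_session:
        init_session.run(
            init_op, options=tf.RunOptions(timeout_in_ms=5 * 60 * 1000))

  scaffold.finalize = _finalize
  return scaffold
```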
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 211c59cb90..e54395f05d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -234,7 +234,7 @@ class _InternalTPUContext(object):
   def mode(self):
     return self._assert_mode()
 
-  def _get_master_address(self):
+  def master_address(self):
     mode = self._assert_mode()
     config = self._config
     master = (
@@ -244,7 +244,7 @@ class _InternalTPUContext(object):
 
   def _get_tpu_system_metadata(self):
     """Gets the (maybe cached) TPU system metadata."""
-    master = self._get_master_address()
+    master = self.master_address()
     tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
     if tpu_system_metadata is not None:
       return tpu_system_metadata
@@ -261,7 +261,7 @@ class _InternalTPUContext(object):
 
   def _get_device_assignment(self):
     """Gets the (maybe cached) TPU device assignment."""
-    master = self._get_master_address()
+    master = self.master_address()
     device_assignment = self._lazy_device_assignment_dict.get(master)
     if device_assignment is not None:
       return device_assignment
@@ -589,7 +589,7 @@ class _InternalTPUContext(object):
           'model-parallelism, the total number of TPU cores should be '
           'num_cores_per_replica * num_replicas. Please set it '
           'accordingly or leave it as `None`'.format(
-              self._get_master_address(), num_replicas,
+              self.master_address(), num_replicas,
               user_provided_num_replicas))
 
       raise ValueError(message)
@@ -644,7 +644,7 @@ class _OneCoreTPUContext(_InternalTPUContext):
 
   def _get_tpu_system_metadata(self):
     """Gets the (maybe cached) TPU system metadata."""
-    master = self._get_master_address()
+    master = self.master_address()
     tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
     if tpu_system_metadata is not None:
       return tpu_system_metadata
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 74157a6193..aa407cf4d8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -43,6 +43,7 @@ from tensorflow.contrib.training.python.training import hparam
 from tensorflow.core.framework import variable_pb2
 from tensorflow.core.framework.summary_pb2 import Summary
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session as session_lib
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
@@ -67,6 +68,7 @@ from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.summary import summary
 from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import evaluation
+from tensorflow.python.training import monitored_session
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training
 from tensorflow.python.training import training_util
@@ -382,7 +384,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
   def begin(self):
     logging.info('TPU job name %s', self._master_job)
     self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
-    self._init_ops = [tpu.initialize_system(job=self._master_job)]
+    self._init_ops = []
+    # For distributed sessions, we can't run initialize_system in a separate
+    # graph here because 'begin' is only invoked when the MonitoredSession is
+    # created. We need to reinitialize the system every time MonitoredSession
+    # creates an underlying tf.Session, so we initialize from Scaffold.finalize.
+    # See _get_and_wrap_scaffold for more details.
+    if self._master_job is None:
+      self._init_ops.append(tpu.initialize_system(job=self._master_job))
     self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
 
     summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
@@ -484,7 +493,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
     return _OpQueueContext(name=name, target=target, args=args)
 
   def after_create_session(self, session, coord):
-    logging.info('Init TPU system')
+    logging.info('Running init_ops')
     session.run(self._init_ops,
                 options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
 
@@ -2700,7 +2709,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
       outputs_from_all_shards=False,
       device_assignment=ctx.device_assignment)
 
-  scaffold = _get_scaffold(captured_scaffold_fn)
+  scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
   return loss, host_calls, scaffold
 
 
@@ -2723,7 +2732,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
       outputs_from_all_shards=False,
       device_assignment=ctx.device_assignment)
 
-  scaffold = _get_scaffold(captured_scaffold_fn)
+  scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
   return loss, host_call, scaffold
 
 
@@ -2751,7 +2760,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
       num_shards=num_cores,
       outputs_from_all_shards=False)
 
-  scaffold = _get_scaffold(captured_scaffold_fn)
+  scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
   return dummy_predict_op, host_calls, scaffold
 
 
@@ -2841,8 +2850,20 @@ class _CapturedObject(object):
     return self._object
 
 
-def _get_scaffold(captured_scaffold_fn):
-  """Retrieves the Scaffold from `captured_scaffold_fn`."""
+def _get_and_wrap_scaffold(captured_scaffold_fn, ctx):
+  """Retrieves the Scaffold from `captured_scaffold_fn`.
+
+  Also wraps the scaffold's finalize method to initialize the TPU after the
+  graph is finalized.
+
+  Args:
+    captured_scaffold_fn: a `_CapturedObject` containing a scaffold_fn.
+    ctx: A `_InternalTPUContext` instance used to initialize the TPU.
+
+  Returns:
+    The Scaffold produced by captured_scaffold_fn, wrapped to initialize the TPU
+    after the graph is finalized.
+  """
   with _CapturingContext(message='Inside scaffold_fn'):
     scaffold_fn = captured_scaffold_fn.get()
     if scaffold_fn:
@@ -2853,14 +2874,64 @@ def _get_scaffold(captured_scaffold_fn):
     else:
       scaffold = None
 
-  if scaffold:
-    wrapped_finalize = scaffold.finalize
-
-    def _finalize():
-      with _CapturingContext('Inside Scaffold.finalize'):
-        wrapped_finalize()
-
-    scaffold.finalize = _finalize
+  if scaffold is None:
+    # When master_address is None, we are using DirectSession, so we can't
+    # invoke initialize_system from finalize. See comments below.
+    if ctx.master_address() is None:
+      return scaffold
+    scaffold = monitored_session.Scaffold()
+
+  wrapped_finalize = scaffold.finalize
+
+  def _finalize():
+    """Invoke wrapped_finalize and initialize the TPU."""
+    with _CapturingContext('Inside Scaffold.finalize'):
+      wrapped_finalize()
+    # Run tpu.initialize_system in its own graph after finalizing the main graph
+    # for distributed sessions. This is necessary because the TPU must be
+    # initialized before the TPU graph rewrite pass runs. We can't put the
+    # initialization op in the main graph because the main graph also contains
+    # replicate ops created by tpu.shard. If we tried to run initialization from
+    # the main graph, the TPU graph rewrite pass would rewrite the replicate ops
+    # before actually evaluating the initialization ops.
+    #
+    # For distributed sessions, the master may independently restart. After a
+    # master restarts, the rewrite pass runs again when any op in the main graph
+    # runs, so we must reinitialize the system every time the main graph is
+    # finalized.
+    #
+    # Special case: When master_address is unset, we're using DirectSession.
+    # DirectSession resets device state between sessions, and uses
+    # place_pruned_graph. Initialization currently passes state to replication
+    # through the TPU_SYSTEM resource manager. Under DirectSession, this
+    # resource manager gets reset when init_session is closed, so DirectSession
+    # can't initialize here, and must instead initialize from the main graph's
+    # init_ops. This is possible with DirectSession because it uses
+    # place_pruned_graph, which removes unreferenced ops before invoking the
+    # rewrite pass. This makes it possible to run init_ops from the main graph,
+    # which contains both tpu.initialize_system and tpu.shard ops, without first
+    # triggering the TPU graph rewrite. We can't do this for distributed
+    # sessions because they don't support place_pruned_graph.
+    #
+    # TODO(b/110943344) Clean this up as part of the initialize_system dataflow
+    # cleanup. It should be possible to remove the special case for
+    # DirectSession and the other call to initialize_system from
+    # _obtain_topology, when topology info is always explicitly passed from
+    # tpu.initialize_system to tpu.shard, though this requires editing or
+    # rebuilding the main graph each time the master restarts.
+    if ctx.master_address() is None:
+      return
+    with ops.Graph().as_default():
+      logging.info('Init TPU system master_address %s', ctx.master_address())
+      with session_lib.Session(
+          ctx.master_address(),
+          config=ctx.config.session_config) as init_session:
+        run_options = config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)
+        init_session.run(
+            tpu.initialize_system(job=ctx.master_job), options=run_options)
+        logging.info('TPU system initialized')
+
+  scaffold.finalize = _finalize
   return scaffold
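On the hook side, `begin()` now adds the initialize op to `init_ops` only in the local case; distributed sessions rely on the wrapped `Scaffold.finalize` shown earlier. A hedged, self-contained sketch of that decision follows, assuming TensorFlow 1.x with `tf.contrib.tpu`; `master_job` is a placeholder for the hook's `_master_job` attribute.

```python
# Sketch (not the actual TPUInfeedOutfeedSessionHook) of the begin()-time
# decision described in the diff. Assumes TensorFlow 1.x with tf.contrib.tpu;
# `master_job` is a hypothetical stand-in for the hook's _master_job.
import tensorflow as tf


def build_init_ops(master_job=None):
  """Returns ops a session hook would run right after session creation."""
  init_ops = []
  if master_job is None:
    # Local DirectSession case in this change: place_pruned_graph removes
    # unreferenced ops before the TPU rewrite pass, so initialize_system can
    # safely live in the main graph and run from init_ops.
    init_ops.append(tf.contrib.tpu.initialize_system(job=master_job))
  # Distributed master: leave init_ops empty; the wrapped Scaffold.finalize
  # reinitializes the TPU in its own graph whenever the main graph is
  # finalized, e.g. after a master restart.
  return init_ops
```

In the real hook, `after_create_session` then runs these init ops with a five-minute timeout, which is why the distributed path must instead reinitialize from `Scaffold.finalize` each time a new underlying session is created.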