author    Jeremy Lau <lauj@google.com>                      2018-07-13 14:29:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-07-13 14:33:08 -0700
commit    590af170ca85a4921db0c28e4fa2785462bdcebd
tree      56ef59605e42a992dbe88c977da0f25e728f0020
parent    3d949e2016a967e303b8ddbd3cbe3bd3408320e8
TPUEstimator: Run tpu.initialize_system() in its own graph whenever the main graph is finalized.

PiperOrigin-RevId: 204529164
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_context.py    |  10
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  | 101
2 files changed, 91 insertions(+), 20 deletions(-)
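
At a high level, this change moves TPU system initialization for distributed sessions out of the finalized main graph and into a short-lived side graph and session. The following is a minimal sketch of that pattern, assuming the TF 1.x contrib TPU API; master_target, session_config, and tpu_job are illustrative parameter names, not identifiers from this patch.

import tensorflow as tf
from tensorflow.contrib import tpu
from tensorflow.core.protobuf import config_pb2


def initialize_tpu_in_side_graph(master_target, session_config=None,
                                 tpu_job=None):
  """Runs tpu.initialize_system() in its own graph, away from the main graph."""
  with tf.Graph().as_default():
    init_op = tpu.initialize_system(job=tpu_job)
    # Allow up to five minutes for initialization, mirroring the timeout used
    # by TPUInfeedOutfeedSessionHook below.
    run_options = config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)
    with tf.Session(master_target, config=session_config) as sess:
      sess.run(init_op, options=run_options)
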
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 211c59cb90..e54395f05d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -234,7 +234,7 @@ class _InternalTPUContext(object):
def mode(self):
return self._assert_mode()
- def _get_master_address(self):
+ def master_address(self):
mode = self._assert_mode()
config = self._config
master = (
@@ -244,7 +244,7 @@ class _InternalTPUContext(object):
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
- master = self._get_master_address()
+ master = self.master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
@@ -261,7 +261,7 @@ class _InternalTPUContext(object):
def _get_device_assignment(self):
"""Gets the (maybe cached) TPU device assignment."""
- master = self._get_master_address()
+ master = self.master_address()
device_assignment = self._lazy_device_assignment_dict.get(master)
if device_assignment is not None:
return device_assignment
@@ -589,7 +589,7 @@ class _InternalTPUContext(object):
'model-parallelism, the total number of TPU cores should be '
'num_cores_per_replica * num_replicas. Please set it '
'accordingly or leave it as `None`'.format(
- self._get_master_address(), num_replicas,
+ self.master_address(), num_replicas,
user_provided_num_replicas))
raise ValueError(message)
@@ -644,7 +644,7 @@ class _OneCoreTPUContext(_InternalTPUContext):
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
- master = self._get_master_address()
+ master = self.master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
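
The tpu_context.py half of the change is a rename only: the private _get_master_address() becomes the public master_address(), so that tpu_estimator.py can ask the context where (and whether) a distributed master is configured. A toy illustration of that decision, using a simplified stand-in rather than the real _InternalTPUContext:

import collections

# FakeRunConfig and SimplifiedTPUContext are stand-ins for illustration; the
# real _InternalTPUContext derives the address from the mode and RunConfig.
FakeRunConfig = collections.namedtuple('FakeRunConfig', ['master'])


class SimplifiedTPUContext(object):

  def __init__(self, config):
    self._config = config

  def master_address(self):
    # Formerly the private helper _get_master_address().
    return self._config.master


ctx = SimplifiedTPUContext(FakeRunConfig(master='grpc://10.240.1.2:8470'))
if ctx.master_address() is None:
  print('Local DirectSession: initialize the TPU from the main graph init_ops.')
else:
  print('Distributed session: initialize the TPU from Scaffold.finalize.')
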
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 74157a6193..aa407cf4d8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -43,6 +43,7 @@ from tensorflow.contrib.training.python.training import hparam
from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
@@ -67,6 +68,7 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
@@ -382,7 +384,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
def begin(self):
logging.info('TPU job name %s', self._master_job)
self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_ops = [tpu.initialize_system(job=self._master_job)]
+ self._init_ops = []
+ # For distributed sessions, we can't run initialize_system in a separate
+ # graph here because 'begin' is only invoked when the MonitoredSession is
+ # created. We need to reinitialize the system every time MonitoredSession
+ # creates an underlying tf.Session, so we initialize from Scaffold.finalize.
+ # See _get_and_wrap_scaffold for more details.
+ if self._master_job is None:
+ self._init_ops.append(tpu.initialize_system(job=self._master_job))
self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
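
With the begin() change above, the in-graph initialize_system op is only built for the local case (no master job); distributed jobs are instead reinitialized from Scaffold.finalize each time the main graph is finalized. A simplified hook showing that split, assuming the TF 1.x contrib TPU API; it is a sketch, not the real TPUInfeedOutfeedSessionHook:

import tensorflow as tf
from tensorflow.contrib import tpu


class SimplifiedTPUInitHook(tf.train.SessionRunHook):
  """Sketch: keep in-graph TPU init only for the local DirectSession case."""

  def __init__(self, master_job=None):
    self._master_job = master_job
    self._init_ops = []
    self._finalize_ops = []

  def begin(self):
    if self._master_job is None:
      # Local DirectSession: place_pruned_graph makes it safe to keep the
      # initialization op in the main graph.
      self._init_ops.append(tpu.initialize_system(job=self._master_job))
    self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]

  def after_create_session(self, session, coord):
    # For distributed jobs this list is empty; the TPU was already initialized
    # from Scaffold.finalize in its own graph.
    session.run(self._init_ops)

  def end(self, session):
    session.run(self._finalize_ops)
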
@@ -484,7 +493,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
return _OpQueueContext(name=name, target=target, args=args)
def after_create_session(self, session, coord):
- logging.info('Init TPU system')
+ logging.info('Running init_ops')
session.run(self._init_ops,
options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
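
The hook still runs whatever remains in init_ops with an explicit five-minute deadline. A self-contained illustration of attaching that timeout to Session.run; the constant op here is only a placeholder for real initialization ops:

import tensorflow as tf
from tensorflow.core.protobuf import config_pb2

with tf.Graph().as_default():
  placeholder_init_op = tf.constant(0, name='placeholder_init_op')
  # Fail with a deadline-exceeded error rather than hanging if the ops stall.
  run_options = config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)
  with tf.Session() as sess:
    sess.run([placeholder_init_op], options=run_options)
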
@@ -2700,7 +2709,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
outputs_from_all_shards=False,
device_assignment=ctx.device_assignment)
- scaffold = _get_scaffold(captured_scaffold_fn)
+ scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
return loss, host_calls, scaffold
@@ -2723,7 +2732,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
outputs_from_all_shards=False,
device_assignment=ctx.device_assignment)
- scaffold = _get_scaffold(captured_scaffold_fn)
+ scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
return loss, host_call, scaffold
@@ -2751,7 +2760,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
num_shards=num_cores,
outputs_from_all_shards=False)
- scaffold = _get_scaffold(captured_scaffold_fn)
+ scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
return dummy_predict_op, host_calls, scaffold
@@ -2841,8 +2850,20 @@ class _CapturedObject(object):
return self._object
-def _get_scaffold(captured_scaffold_fn):
- """Retrieves the Scaffold from `captured_scaffold_fn`."""
+def _get_and_wrap_scaffold(captured_scaffold_fn, ctx):
+ """Retrieves the Scaffold from `captured_scaffold_fn`.
+
+ Also wraps the scaffold's finalize method to initialize the TPU after the
+ graph is finalized.
+
+ Args:
+ captured_scaffold_fn: a `_CapturedObject` containing a scaffold_fn.
+ ctx: A `_InternalTPUContext` instance used to initialize the TPU.
+
+ Returns:
+ The Scaffold produced by captured_scaffold_fn, wrapped to initialize the TPU
+ after the graph is finalized.
+ """
with _CapturingContext(message='Inside scaffold_fn'):
scaffold_fn = captured_scaffold_fn.get()
if scaffold_fn:
@@ -2853,14 +2874,64 @@ def _get_scaffold(captured_scaffold_fn):
else:
scaffold = None
- if scaffold:
- wrapped_finalize = scaffold.finalize
-
- def _finalize():
- with _CapturingContext('Inside Scaffold.finalize'):
- wrapped_finalize()
-
- scaffold.finalize = _finalize
+ if scaffold is None:
+ # When master_address is None, we are using DirectSession, so we can't
+ # invoke initialize_system from finalize. See comments below.
+ if ctx.master_address() is None:
+ return scaffold
+ scaffold = monitored_session.Scaffold()
+
+ wrapped_finalize = scaffold.finalize
+
+ def _finalize():
+ """Invoke wrapped_finalize and initialize the TPU."""
+ with _CapturingContext('Inside Scaffold.finalize'):
+ wrapped_finalize()
+ # Run tpu.initialize_system in its own graph after finalizing the main graph
+ # for distributed sessions. This is necessary because the TPU must be
+ # initialized before the TPU graph rewrite pass runs. We can't put the
+ # initialization op in the main graph because the main graph also contains
+ # replicate ops created by tpu.shard. If we tried to run initialization from
+ # the main graph, the TPU graph rewrite pass would rewrite the replicate ops
+ # before actually evaluating the initialization ops.
+ #
+ # For distributed sessions, the master may independently restart. After a
+ # master restarts, the rewrite pass runs again when any op in the main graph
+ # runs, so we must reinitialize the system every time the main graph is
+ # finalized.
+ #
+ # Special case: When master_address is unset, we're using DirectSession.
+ # DirectSession resets device state between sessions, and uses
+ # place_pruned_graph. Initialization currently passes state to replication
+ # through the TPU_SYSTEM resource manager. Under DirectSession, this
+ # resource manager gets reset when init_session is closed, so DirectSession
+ # can't initialize here, and must instead initialize from the main graph's
+ # init_ops. This is possible with DirectSession because it uses
+ # place_pruned_graph, which removes unreferenced ops before invoking the
+ # rewrite pass. This makes it possible to run init_ops from the main graph,
+ # which contains both tpu.initialize_system and tpu.shard ops, without first
+ # triggering the TPU graph rewrite. We can't do this for distributed
+ # sessions because they don't support place_pruned_graph.
+ #
+ # TODO(b/110943344) Clean this up as part of the initialize_system dataflow
+ # cleanup. It should be possible to remove the special case for
+ # DirectSession and the other call to initialize_system from
+ # _obtain_topology, when topology info is always explicitly passed from
+ # tpu.initialize_system to tpu.shard, though this requires editing or
+ # rebuilding the main graph each time the master restarts.
+ if ctx.master_address() is None:
+ return
+ with ops.Graph().as_default():
+ logging.info('Init TPU system master_address %s', ctx.master_address())
+ with session_lib.Session(
+ ctx.master_address(),
+ config=ctx.config.session_config) as init_session:
+ run_options = config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)
+ init_session.run(
+ tpu.initialize_system(job=ctx.master_job), options=run_options)
+ logging.info('TPU system initialized')
+
+ scaffold.finalize = _finalize
return scaffold
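
For reference, the core of _get_and_wrap_scaffold above can be distilled into a standalone helper. The sketch below assumes the TF 1.x contrib TPU API; wrap_scaffold_with_tpu_init, master_target, and session_config are illustrative names, and unlike the patch it always returns a Scaffold, with _finalize simply skipping the side-graph initialization in the local case.

import tensorflow as tf
from tensorflow.contrib import tpu
from tensorflow.core.protobuf import config_pb2


def wrap_scaffold_with_tpu_init(scaffold, master_target, session_config=None,
                                master_job=None):
  """Extends scaffold.finalize() to initialize the TPU in a separate graph."""
  if scaffold is None:
    scaffold = tf.train.Scaffold()
  original_finalize = scaffold.finalize

  def _finalize():
    original_finalize()  # Finalizes the main graph as usual.
    if master_target is None:
      # Local DirectSession: initialization stays in the main graph's init_ops.
      return
    # Distributed session: initialize from a fresh graph so the TPU rewrite
    # pass never sees initialize_system alongside the replicated ops.
    with tf.Graph().as_default():
      init_op = tpu.initialize_system(job=master_job)
      run_options = config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)
      with tf.Session(master_target, config=session_config) as init_session:
        init_session.run(init_op, options=run_options)

  scaffold.finalize = _finalize
  return scaffold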