diff options
author | Russell Power <power@google.com> | 2018-09-27 14:03:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 14:10:40 -0700 |
commit | 2fb9377a5ec610b8eff853fd1d2d53eabf711eda (patch) | |
tree | 9ecd96e8a4c39caa31e769ca34f49333b272ec48 /tensorflow/contrib/tpu | |
parent | 5220e565b7cc32a5f757896c76c7d57c33bcd323 (diff) |
Enable worker heartbeat polling for all available workers.
PiperOrigin-RevId: 214831772
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/session_support.py | 52 |
1 files changed, 32 insertions, 20 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index 3e91e2df32..24b9bd136b 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -41,6 +41,25 @@ class CoordinatorShutdownException(Exception): pass +def _make_heartbeat_op(session, device, request_ph): + """Return a heartbeat op or None if heartbeats are not supported by device.""" + try: + with ops.device(device): + heartbeat_op = tpu_ops.worker_heartbeat(request_ph) + request = event_pb2.WorkerHeartbeatRequest() + options = config_pb2.RunOptions(timeout_in_ms=5000) + session.run( + heartbeat_op, + feed_dict={request_ph: request.SerializeToString()}, + options=options) + return heartbeat_op + except errors.InvalidArgumentError as _: + return None + except errors.DeadlineExceededError as _: + logging.warning('Timeout connecting to %s when testing heartbeat', device) + return None + + class WorkerHeartbeatManager(object): """Manages the status/heartbeat monitor for a set of workers.""" @@ -72,30 +91,27 @@ class WorkerHeartbeatManager(object): name='worker_heartbeat_request', dtype=dtypes.string) heartbeat_ops = [] + kept_devices = [] for device in devices: - with ops.device(device): - heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder)) + heartbeat_op = _make_heartbeat_op(session, device, request_placeholder) + if heartbeat_op is not None: + kept_devices.append(device) + heartbeat_ops.append(heartbeat_op) + else: + logging.warning('Heartbeat support not available for %s', device) - return WorkerHeartbeatManager(session, devices, heartbeat_ops, + return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops, request_placeholder) - def heartbeat_supported(self): - """Returns True if heartbeat operations are supported on all workers.""" - try: - # Send ping to verify worker has heartbeat support. - self.ping() - return True - except errors.InvalidArgumentError as _: - return False + def num_workers(self): + return len(self._devices) def configure(self, message): """Configure heartbeat manager for all devices. Args: message: `event_pb2.WorkerHeartbeatRequest` - Returns: `None` - """ logging.info('Configuring worker heartbeat: %s', text_format.MessageToString(message)) @@ -184,7 +200,6 @@ class WatchdogManager(threading.Thread): """Initialize a watchdog manager. Args: - session: Session connected to worker devices. A cloned session and graph will be created for managing worker pings. devices: Set of devices to monitor. If none, all workers will be @@ -277,16 +292,14 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): target=training_session.sess_str, graph=self._graph) self._workers = WorkerHeartbeatManager.from_devices( self._session, all_worker_devices(self._session)) - self._heartbeat_supported = self._workers.heartbeat_supported() + self._heartbeat_supported = self._workers.num_workers() > 0 if self._heartbeat_supported: self._workers.configure( event_pb2.WorkerHeartbeatRequest( shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) else: logging.warn( - 'Worker heartbeats not supported by all workers. No failure ' - 'handling will be enabled.' - ) + 'No workers support hearbeats. Failure handling will be disabled.') def saver(self): if self._saver: @@ -303,8 +316,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): logging.error( 'Multiple savers in the SAVERS collection. On-demand checkpointing ' 'will be disabled. Pass an explicit `saver` to the constructor to ' - 'override this behavior.' - ) + 'override this behavior.') return None return savers[0] |