aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-09-27 14:03:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 14:10:40 -0700
commit2fb9377a5ec610b8eff853fd1d2d53eabf711eda (patch)
tree9ecd96e8a4c39caa31e769ca34f49333b272ec48 /tensorflow/contrib/tpu
parent5220e565b7cc32a5f757896c76c7d57c33bcd323 (diff)
Enable worker heartbeat polling for all available workers.
PiperOrigin-RevId: 214831772
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/session_support.py52
1 files changed, 32 insertions, 20 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py
index 3e91e2df32..24b9bd136b 100644
--- a/tensorflow/contrib/tpu/python/tpu/session_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/session_support.py
@@ -41,6 +41,25 @@ class CoordinatorShutdownException(Exception):
pass
+def _make_heartbeat_op(session, device, request_ph):
+ """Return a heartbeat op or None if heartbeats are not supported by device."""
+ try:
+ with ops.device(device):
+ heartbeat_op = tpu_ops.worker_heartbeat(request_ph)
+ request = event_pb2.WorkerHeartbeatRequest()
+ options = config_pb2.RunOptions(timeout_in_ms=5000)
+ session.run(
+ heartbeat_op,
+ feed_dict={request_ph: request.SerializeToString()},
+ options=options)
+ return heartbeat_op
+ except errors.InvalidArgumentError as _:
+ return None
+ except errors.DeadlineExceededError as _:
+ logging.warning('Timeout connecting to %s when testing heartbeat', device)
+ return None
+
+
class WorkerHeartbeatManager(object):
"""Manages the status/heartbeat monitor for a set of workers."""
@@ -72,30 +91,27 @@ class WorkerHeartbeatManager(object):
name='worker_heartbeat_request', dtype=dtypes.string)
heartbeat_ops = []
+ kept_devices = []
for device in devices:
- with ops.device(device):
- heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
+ heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
+ if heartbeat_op is not None:
+ kept_devices.append(device)
+ heartbeat_ops.append(heartbeat_op)
+ else:
+ logging.warning('Heartbeat support not available for %s', device)
- return WorkerHeartbeatManager(session, devices, heartbeat_ops,
+ return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
request_placeholder)
- def heartbeat_supported(self):
- """Returns True if heartbeat operations are supported on all workers."""
- try:
- # Send ping to verify worker has heartbeat support.
- self.ping()
- return True
- except errors.InvalidArgumentError as _:
- return False
+ def num_workers(self):
+ return len(self._devices)
def configure(self, message):
"""Configure heartbeat manager for all devices.
Args:
message: `event_pb2.WorkerHeartbeatRequest`
-
Returns: `None`
-
"""
logging.info('Configuring worker heartbeat: %s',
text_format.MessageToString(message))
@@ -184,7 +200,6 @@ class WatchdogManager(threading.Thread):
"""Initialize a watchdog manager.
Args:
-
session: Session connected to worker devices. A cloned session and graph
will be created for managing worker pings.
devices: Set of devices to monitor. If none, all workers will be
@@ -277,16 +292,14 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
target=training_session.sess_str, graph=self._graph)
self._workers = WorkerHeartbeatManager.from_devices(
self._session, all_worker_devices(self._session))
- self._heartbeat_supported = self._workers.heartbeat_supported()
+ self._heartbeat_supported = self._workers.num_workers() > 0
if self._heartbeat_supported:
self._workers.configure(
event_pb2.WorkerHeartbeatRequest(
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
else:
logging.warn(
- 'Worker heartbeats not supported by all workers. No failure '
- 'handling will be enabled.'
- )
+ 'No workers support hearbeats. Failure handling will be disabled.')
def saver(self):
if self._saver:
@@ -303,8 +316,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
logging.error(
'Multiple savers in the SAVERS collection. On-demand checkpointing '
'will be disabled. Pass an explicit `saver` to the constructor to '
- 'override this behavior.'
- )
+ 'override this behavior.')
return None
return savers[0]