aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/monitored_session.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-16 11:20:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-16 11:26:26 -0800
commit7065160c6c67499df859012c55545218aa6a549a (patch)
tree10245e7285c26c062f461b63df60854affe18d0b /tensorflow/python/training/monitored_session.py
parente47032ece9b5fb8f5683e1eedb8ee8870bd48022 (diff)
Plumb worker max_wait_secs arguments up to tf.contrib.train.train.
PiperOrigin-RevId: 175991159
Diffstat (limited to 'tensorflow/python/training/monitored_session.py')
-rw-r--r--tensorflow/python/training/monitored_session.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 1f6016a91b..e931555470 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -281,7 +281,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
save_summaries_secs=USE_DEFAULT,
config=None,
stop_grace_period_secs=120,
- log_step_count_steps=100):
+ log_step_count_steps=100,
+ max_wait_secs=7200):
"""Creates a `MonitoredSession` for training.
For a chief, this utility sets proper session initializer/restorer. It also
@@ -320,6 +321,10 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
`close()` has been called.
log_step_count_steps: The frequency, in number of global steps, that the
global step/sec is logged.
+ max_wait_secs: Maximum time workers should wait for the session to
+ become available. This should be kept relatively short to help detect
+ incorrect code, but sometimes may need to be increased if the chief takes
+ a while to start up.
Returns:
A `MonitoredSession` object.
@@ -335,7 +340,10 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
scaffold = scaffold or Scaffold()
if not is_chief:
session_creator = WorkerSessionCreator(
- scaffold=scaffold, master=master, config=config)
+ scaffold=scaffold,
+ master=master,
+ config=config,
+ max_wait_secs=max_wait_secs)
return MonitoredSession(session_creator=session_creator, hooks=hooks or [],
stop_grace_period_secs=stop_grace_period_secs)
@@ -434,7 +442,11 @@ class ChiefSessionCreator(SessionCreator):
class WorkerSessionCreator(SessionCreator):
"""Creates a tf.Session for a worker."""
- def __init__(self, scaffold=None, master='', config=None):
+ def __init__(self,
+ scaffold=None,
+ master='',
+ config=None,
+ max_wait_secs=30 * 60):
"""Initializes a worker session creator.
Args:
@@ -442,11 +454,13 @@ class WorkerSessionCreator(SessionCreator):
not specified a default one is created. It's used to finalize the graph.
master: `String` representation of the TensorFlow master to use.
config: `ConfigProto` proto used to configure the session.
+ max_wait_secs: Maximum time to wait for the session to become available.
"""
self._scaffold = scaffold or Scaffold()
self._session_manager = None
self._master = master
self._config = config
+ self._max_wait_secs = max_wait_secs
def _get_session_manager(self):
if self._session_manager:
@@ -463,7 +477,7 @@ class WorkerSessionCreator(SessionCreator):
self._scaffold.finalize()
return self._get_session_manager().wait_for_session(
self._master, config=self._config,
- max_wait_secs=30 * 60 # Wait up to 30 mins for the session to be ready.
+ max_wait_secs=self._max_wait_secs
)