diff options
author | 2017-11-16 11:20:21 -0800 | |
---|---|---|
committer | 2017-11-16 11:26:26 -0800 | |
commit | 7065160c6c67499df859012c55545218aa6a549a (patch) | |
tree | 10245e7285c26c062f461b63df60854affe18d0b /tensorflow/python/training/monitored_session.py | |
parent | e47032ece9b5fb8f5683e1eedb8ee8870bd48022 (diff) |
Plumb the worker max_wait_secs argument up to tf.contrib.train.train.
PiperOrigin-RevId: 175991159
Diffstat (limited to 'tensorflow/python/training/monitored_session.py')
-rw-r--r-- | tensorflow/python/training/monitored_session.py | 22 |
1 file changed, 18 insertions, 4 deletions
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 1f6016a91b..e931555470 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -281,7 +281,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
                              save_summaries_secs=USE_DEFAULT,
                              config=None,
                              stop_grace_period_secs=120,
-                             log_step_count_steps=100):
+                             log_step_count_steps=100,
+                             max_wait_secs=7200):
   """Creates a `MonitoredSession` for training.
 
   For a chief, this utility sets proper session initializer/restorer. It also
@@ -320,6 +321,10 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
       `close()` has been called.
     log_step_count_steps: The frequency, in number of global steps, that the
       global step/sec is logged.
+    max_wait_secs: Maximum time workers should wait for the session to
+      become available. This should be kept relatively short to help detect
+      incorrect code, but sometimes may need to be increased if the chief takes
+      a while to start up.
 
   Returns:
     A `MonitoredSession` object.
@@ -335,7 +340,10 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
   scaffold = scaffold or Scaffold()
   if not is_chief:
     session_creator = WorkerSessionCreator(
-        scaffold=scaffold, master=master, config=config)
+        scaffold=scaffold,
+        master=master,
+        config=config,
+        max_wait_secs=max_wait_secs)
     return MonitoredSession(session_creator=session_creator, hooks=hooks or [],
                             stop_grace_period_secs=stop_grace_period_secs)
 
@@ -434,7 +442,11 @@ class ChiefSessionCreator(SessionCreator):
 class WorkerSessionCreator(SessionCreator):
   """Creates a tf.Session for a worker."""
 
-  def __init__(self, scaffold=None, master='', config=None):
+  def __init__(self,
+               scaffold=None,
+               master='',
+               config=None,
+               max_wait_secs=30 * 60):
     """Initializes a worker session creator.
 
     Args:
@@ -442,11 +454,13 @@ class WorkerSessionCreator(SessionCreator):
         not specified a default one is created. It's used to finalize the graph.
       master: `String` representation of the TensorFlow master to use.
       config: `ConfigProto` proto used to configure the session.
+      max_wait_secs: Maximum time to wait for the session to become available.
     """
     self._scaffold = scaffold or Scaffold()
     self._session_manager = None
     self._master = master
     self._config = config
+    self._max_wait_secs = max_wait_secs
 
   def _get_session_manager(self):
     if self._session_manager:
@@ -463,7 +477,7 @@ class WorkerSessionCreator(SessionCreator):
       self._scaffold.finalize()
     return self._get_session_manager().wait_for_session(
         self._master, config=self._config,
-        max_wait_secs=30 * 60  # Wait up to 30 mins for the session to be ready.
+        max_wait_secs=self._max_wait_secs
     )
 |