about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/training/monitored_session.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/monitored_session.py')
-rw-r--r--  tensorflow/python/training/monitored_session.py  33
1 file changed, 27 insertions, 6 deletions
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 6c5c9e01a7..4ce6f6d002 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -281,13 +281,14 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
scaffold=None,
hooks=None,
chief_only_hooks=None,
- save_checkpoint_secs=600,
+ save_checkpoint_secs=USE_DEFAULT,
save_summaries_steps=USE_DEFAULT,
save_summaries_secs=USE_DEFAULT,
config=None,
stop_grace_period_secs=120,
log_step_count_steps=100,
- max_wait_secs=7200):
+ max_wait_secs=7200,
+ save_checkpoint_steps=USE_DEFAULT):
"""Creates a `MonitoredSession` for training.
For a chief, this utility sets proper session initializer/restorer. It also
@@ -310,8 +311,10 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks if
`is_chief==True`, ignore otherwise.
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
- using a default checkpoint saver. If `save_checkpoint_secs` is set to
- `None`, then the default checkpoint saver isn't used.
+ using a default checkpoint saver. If both `save_checkpoint_steps` and
+ `save_checkpoint_secs` are set to `None`, then the default checkpoint
+ saver isn't used. If both are provided, then only `save_checkpoint_secs`
+ is used. Default 600.
save_summaries_steps: The frequency, in number of global steps, that the
summaries are written to disk using a default summary saver. If both
`save_summaries_steps` and `save_summaries_secs` are set to `None`, then
@@ -330,6 +333,11 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
become available. This should be kept relatively short to help detect
incorrect code, but sometimes may need to be increased if the chief takes
a while to start up.
+ save_checkpoint_steps: The frequency, in number of global steps, that a
+ checkpoint is saved using a default checkpoint saver. If both
+ `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then
+ the default checkpoint saver isn't used. If both are provided, then only
+ `save_checkpoint_secs` is used. Default not enabled.
Returns:
A `MonitoredSession` object.
@@ -342,6 +350,15 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
elif save_summaries_steps == USE_DEFAULT:
save_summaries_steps = None
+ if (save_checkpoint_steps == USE_DEFAULT and
+ save_checkpoint_secs == USE_DEFAULT):
+ save_checkpoint_steps = None
+ save_checkpoint_secs = 600
+ elif save_checkpoint_secs == USE_DEFAULT:
+ save_checkpoint_secs = None
+ elif save_checkpoint_steps == USE_DEFAULT:
+ save_checkpoint_steps = None
+
scaffold = scaffold or Scaffold()
if not is_chief:
session_creator = WorkerSessionCreator(
@@ -374,9 +391,13 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
output_dir=checkpoint_dir))
- if save_checkpoint_secs and save_checkpoint_secs > 0:
+ if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
+ save_checkpoint_steps and save_checkpoint_steps > 0):
all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
- checkpoint_dir, save_secs=save_checkpoint_secs, scaffold=scaffold))
+ checkpoint_dir,
+ save_steps=save_checkpoint_steps,
+ save_secs=save_checkpoint_secs,
+ scaffold=scaffold))
if hooks:
all_hooks.extend(hooks)