aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/monitored_session.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-14 02:21:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 02:23:56 -0700
commit83a48e092b6282f7fdbf4b0059eb0da146b68f42 (patch)
treee7e1e8e57fbc8ddc57eed0bcd0c662d7ca693c6f /tensorflow/python/training/monitored_session.py
parent8d9787bed57f1dd5d697ff847cd5598ecc032620 (diff)
Provide the ability to specify, in tf.train.MonitoredTrainingSession(), a separate summary directory.
When set, summary_dir is passed as output directory to StepCounterHook and SummarySaverHook. When unset, the behavior is unchanged and checkpoint_dir is used instead. PiperOrigin-RevId: 200526130
Diffstat (limited to 'tensorflow/python/training/monitored_session.py')
-rw-r--r--tensorflow/python/training/monitored_session.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index fece3370f3..7b06bffa4b 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -298,7 +298,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
stop_grace_period_secs=120,
log_step_count_steps=100,
max_wait_secs=7200,
- save_checkpoint_steps=USE_DEFAULT):
+ save_checkpoint_steps=USE_DEFAULT,
+ summary_dir=None):
"""Creates a `MonitoredSession` for training.
For a chief, this utility sets proper session initializer/restorer. It also
@@ -348,6 +349,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
`save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then
the default checkpoint saver isn't used. If both are provided, then only
`save_checkpoint_secs` is used. Default not enabled.
+ summary_dir: A string. Optional path to a directory where to
+ save summaries. If None, checkpoint_dir is used instead.
Returns:
A `MonitoredSession` object.
@@ -388,11 +391,12 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
master=master,
config=config)
- if checkpoint_dir:
+ summary_dir = summary_dir or checkpoint_dir
+ if summary_dir:
if log_step_count_steps and log_step_count_steps > 0:
all_hooks.append(
basic_session_run_hooks.StepCounterHook(
- output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))
+ output_dir=summary_dir, every_n_steps=log_step_count_steps))
if (save_summaries_steps and save_summaries_steps > 0) or (
save_summaries_secs and save_summaries_secs > 0):
@@ -400,7 +404,9 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
scaffold=scaffold,
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
- output_dir=checkpoint_dir))
+ output_dir=summary_dir))
+
+ if checkpoint_dir:
if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
save_checkpoint_steps and save_checkpoint_steps > 0):
all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(