diff options
author | 2018-06-14 02:21:04 -0700 | |
---|---|---|
committer | 2018-06-14 02:23:56 -0700 | |
commit | 83a48e092b6282f7fdbf4b0059eb0da146b68f42 (patch) | |
tree | e7e1e8e57fbc8ddc57eed0bcd0c662d7ca693c6f /tensorflow/python/training/monitored_session.py | |
parent | 8d9787bed57f1dd5d697ff847cd5598ecc032620 (diff) |
Provide the ability to specify, in tf.train.MonitoredTrainingSession(), a separate summary directory.
When set, summary_dir is passed as output directory to StepCounterHook and SummarySaverHook.
When unset, the behavior is unchanged and checkpoint_dir is used instead.
PiperOrigin-RevId: 200526130
Diffstat (limited to 'tensorflow/python/training/monitored_session.py')
-rw-r--r-- | tensorflow/python/training/monitored_session.py | 14 |
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index fece3370f3..7b06bffa4b 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -298,7 +298,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name stop_grace_period_secs=120, log_step_count_steps=100, max_wait_secs=7200, - save_checkpoint_steps=USE_DEFAULT): + save_checkpoint_steps=USE_DEFAULT, + summary_dir=None): """Creates a `MonitoredSession` for training. For a chief, this utility sets proper session initializer/restorer. It also @@ -348,6 +349,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then the default checkpoint saver isn't used. If both are provided, then only `save_checkpoint_secs` is used. Default not enabled. + summary_dir: A string. Optional path to a directory where to + save summaries. If None, checkpoint_dir is used instead. Returns: A `MonitoredSession` object. 
@@ -388,11 +391,12 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name master=master, config=config) - if checkpoint_dir: + summary_dir = summary_dir or checkpoint_dir + if summary_dir: if log_step_count_steps and log_step_count_steps > 0: all_hooks.append( basic_session_run_hooks.StepCounterHook( - output_dir=checkpoint_dir, every_n_steps=log_step_count_steps)) + output_dir=summary_dir, every_n_steps=log_step_count_steps)) if (save_summaries_steps and save_summaries_steps > 0) or ( save_summaries_secs and save_summaries_secs > 0): @@ -400,7 +404,9 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name scaffold=scaffold, save_steps=save_summaries_steps, save_secs=save_summaries_secs, - output_dir=checkpoint_dir)) + output_dir=summary_dir)) + + if checkpoint_dir: if (save_checkpoint_secs and save_checkpoint_secs > 0) or ( save_checkpoint_steps and save_checkpoint_steps > 0): all_hooks.append(basic_session_run_hooks.CheckpointSaverHook( |