From 93763d60644117596ed90e8f4a79c277b59986e0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 16 Oct 2016 03:31:38 -0800 Subject: Fix AttributeErrors in CheckpointSaverHook and SummarySaverHook. When CheckpointSaverHook is not given a saver or scaffold an AttributeError will be raised by _save because it will try to access self._scaffold.saver. Similarly, when SummarySaverHook is not given a scaffold or a summary_op then an AttributeError will be raised in before_run because it will try to access self._scaffold.summary_op. Change: 136282986 --- tensorflow/python/training/basic_session_run_hooks.py | 17 ++++++++++++++--- .../python/training/basic_session_run_hooks_test.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) (limited to 'tensorflow') diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index e24ec9faa1..5d00672994 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -208,8 +208,12 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): Raises: ValueError: One of `save_steps` or `save_secs` should be set. + ValueError: Exactly one of saver or scaffold should be set. """ logging.info("Create CheckpointSaverHook.") + if ((saver is None and scaffold is None) or + (saver is not None and scaffold is not None)): + raise ValueError("Exactly one of saver or scaffold must be provided.") self._saver = saver self._checkpoint_dir = checkpoint_dir self._summary_writer = SummaryWriterCache.get(checkpoint_dir) @@ -255,10 +259,10 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): def _save(self, step, session): """Saves the latest checkpoint.""" logging.info("Saving checkpoints for %d into %s.", step, self._save_path) - if self._saver is None: - self._scaffold.saver.save(session, self._save_path, global_step=step) - else: + if self._saver is not None: self._saver.save(session, self._save_path, global_step=step) + elif self._scaffold is not None: + self._scaffold.saver.save(session, self._save_path, global_step=step) self._summary_writer.add_session_log( SessionLog( status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), @@ -370,7 +374,14 @@ class SummarySaverHook(session_run_hook.SessionRunHook): summary_op: `Tensor` of type `string`. A serialized `Summary` protocol buffer, as output by TF summary methods like `scalar_summary` or `merge_all_summaries`. + + Raises: + ValueError: Exactly one of scaffold or summary_op should be set. """ + if ((scaffold is None and summary_op is None) or + (scaffold is not None and summary_op is not None)): + raise ValueError( + "Exactly one of scaffold or summary_op must be provided.") self._summary_op = summary_op self._summary_writer = summary_writer if summary_writer is None and output_dir: diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index 858f4bc1a8..4a346611d3 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -173,6 +173,15 @@ class CheckpointSaverHookTest(tf.test.TestCase): def tearDown(self): shutil.rmtree(self.model_dir, ignore_errors=True) + def test_raise_when_saver_and_scaffold_both_missing(self): + with self.assertRaises(ValueError): + tf.train.CheckpointSaverHook(self.model_dir) + + def test_raise_when_saver_and_scaffold_both_present(self): + with self.assertRaises(ValueError): + tf.train.CheckpointSaverHook( + self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold) + def test_raise_in_both_secs_and_steps(self): with self.assertRaises(ValueError): tf.train.CheckpointSaverHook(self.model_dir, save_secs=10, save_steps=20) @@ -329,6 +338,15 @@ class SummarySaverHookTest(tf.test.TestCase): global_step = tf.contrib.framework.get_or_create_global_step() self.train_op = tf.assign_add(global_step, 1) + def test_raise_when_scaffold_and_summary_op_both_missing(self): + with self.assertRaises(ValueError): + tf.train.SummarySaverHook() + + def test_raise_when_scaffold_and_summary_op_both_present(self): + with self.assertRaises(ValueError): + tf.train.SummarySaverHook(scaffold=tf.train.Scaffold(), + summary_op=self.summary_op) + def test_raise_in_both_secs_and_steps(self): with self.assertRaises(ValueError): tf.train.SummarySaverHook( -- cgit v1.2.3