aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-16 03:31:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-16 04:48:52 -0700
commit93763d60644117596ed90e8f4a79c277b59986e0 (patch)
tree39d8942db1882954346f0b7be503cfceeea66ef4 /tensorflow
parent55be1727e2f2196055dd29813485930c48583b41 (diff)
Fix AttributeErrors in CheckpointSaverHook and SummarySaverHook.
When CheckpointSaverHook is not given a saver or scaffold an AttributeError will be raised by _save because it will try to access self._scaffold.saver. Similarly, when SummarySaverHook is not given a scaffold or a summary_op then an AttributeError will be raised in before_run because it will try to access self._scaffold.summary_op. Change: 136282986
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py17
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py18
2 files changed, 32 insertions, 3 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index e24ec9faa1..5d00672994 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -208,8 +208,12 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
+ ValueError: Exactly one of saver or scaffold should be set.
"""
logging.info("Create CheckpointSaverHook.")
+ if ((saver is None and scaffold is None) or
+ (saver is not None and scaffold is not None)):
+ raise ValueError("Exactly one of saver or scaffold must be provided.")
self._saver = saver
self._checkpoint_dir = checkpoint_dir
self._summary_writer = SummaryWriterCache.get(checkpoint_dir)
@@ -255,10 +259,10 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
def _save(self, step, session):
"""Saves the latest checkpoint."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
- if self._saver is None:
- self._scaffold.saver.save(session, self._save_path, global_step=step)
- else:
+ if self._saver is not None:
self._saver.save(session, self._save_path, global_step=step)
+ elif self._scaffold is not None:
+ self._scaffold.saver.save(session, self._save_path, global_step=step)
self._summary_writer.add_session_log(
SessionLog(
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
@@ -370,7 +374,14 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
summary_op: `Tensor` of type `string`. A serialized `Summary` protocol
buffer, as output by TF summary methods like `scalar_summary` or
`merge_all_summaries`.
+
+ Raises:
+ ValueError: Exactly one of scaffold or summary_op should be set.
"""
+ if ((scaffold is None and summary_op is None) or
+ (scaffold is not None and summary_op is not None)):
+ raise ValueError(
+ "Exactly one of scaffold or summary_op must be provided.")
self._summary_op = summary_op
self._summary_writer = summary_writer
if summary_writer is None and output_dir:
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index 858f4bc1a8..4a346611d3 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -173,6 +173,15 @@ class CheckpointSaverHookTest(tf.test.TestCase):
def tearDown(self):
shutil.rmtree(self.model_dir, ignore_errors=True)
+ def test_raise_when_saver_and_scaffold_both_missing(self):
+ with self.assertRaises(ValueError):
+ tf.train.CheckpointSaverHook(self.model_dir)
+
+ def test_raise_when_saver_and_scaffold_both_present(self):
+ with self.assertRaises(ValueError):
+ tf.train.CheckpointSaverHook(
+ self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold)
+
def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError):
tf.train.CheckpointSaverHook(self.model_dir, save_secs=10, save_steps=20)
@@ -329,6 +338,15 @@ class SummarySaverHookTest(tf.test.TestCase):
global_step = tf.contrib.framework.get_or_create_global_step()
self.train_op = tf.assign_add(global_step, 1)
+ def test_raise_when_scaffold_and_summary_op_both_missing(self):
+ with self.assertRaises(ValueError):
+ tf.train.SummarySaverHook()
+
+ def test_raise_when_scaffold_and_summary_op_both_present(self):
+ with self.assertRaises(ValueError):
+ tf.train.SummarySaverHook(scaffold=tf.train.Scaffold(),
+ summary_op=self.summary_op)
+
def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError):
tf.train.SummarySaverHook(