aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-30 14:00:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 14:09:08 -0700
commit48abc110a1a51ef5d6482583c6791c8163873721 (patch)
tree86b2b5f0ac13e5e288752c13ae6837f59842ccfc /tensorflow/contrib/training
parentfeffc075befac53dddc721572493796c8fbffe3c (diff)
Make SummaryAtEndHook work even if there are no summaries in the graph.
Besides just general resilience to general user code, another motivation is that it still makes sense to use the hook when there are no summaries in the graph for the side effect of writing out the graph summary. PiperOrigin-RevId: 210975165
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/evaluation.py9
-rw-r--r--tensorflow/contrib/training/python/training/evaluation_test.py31
2 files changed, 34 insertions, 6 deletions
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index 01bac891da..16a647bf66 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -296,6 +296,7 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook):
def begin(self):
if self._replace_summary_op:
+ # This can still remain None if there are no summaries.
self._summary_op = summary.merge_all()
self._global_step = training_util.get_or_create_global_step()
@@ -304,10 +305,12 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook):
self._summary_writer = summary.FileWriterCache.get(self._log_dir)
def end(self, session):
- global_step = training_util.global_step(session, self._global_step)
- summary_str = session.run(self._summary_op, self._feed_dict)
+ if self._summary_op is not None:
+ global_step = training_util.global_step(session, self._global_step)
+ summary_str = session.run(self._summary_op, self._feed_dict)
+ if self._summary_writer:
+ self._summary_writer.add_summary(summary_str, global_step)
if self._summary_writer:
- self._summary_writer.add_summary(summary_str, global_step)
self._summary_writer.flush()
diff --git a/tensorflow/contrib/training/python/training/evaluation_test.py b/tensorflow/contrib/training/python/training/evaluation_test.py
index ec47fe5d97..ddd135f047 100644
--- a/tensorflow/contrib/training/python/training/evaluation_test.py
+++ b/tensorflow/contrib/training/python/training/evaluation_test.py
@@ -427,9 +427,11 @@ class EvaluateRepeatedlyTest(test.TestCase):
names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
return names_to_values, names_to_updates
- def _verify_summaries(self, output_dir, names_to_values):
+ def _verify_events(self, output_dir, names_to_values):
"""Verifies that the given `names_to_values` are found in the summaries.
+ Also checks that a GraphDef was written out to the events file.
+
Args:
output_dir: An existing directory where summaries are found.
names_to_values: A dictionary of strings to values.
@@ -440,7 +442,13 @@ class EvaluateRepeatedlyTest(test.TestCase):
self.assertEqual(len(output_filepath), 1)
events = summary_iterator.summary_iterator(output_filepath[0])
- summaries = [e.summary for e in events if e.summary.value]
+ summaries = []
+ graph_def = None
+ for event in events:
+ if event.summary.value:
+ summaries.append(event.summary)
+ elif event.graph_def:
+ graph_def = event.graph_def
values = []
for summary in summaries:
for value in summary.value:
@@ -448,6 +456,7 @@ class EvaluateRepeatedlyTest(test.TestCase):
saved_results = {v.tag: v.simple_value for v in values}
for name in names_to_values:
self.assertAlmostEqual(names_to_values[name], saved_results[name], 5)
+ self.assertIsNotNone(graph_def)
def testSummariesAreFlushedToDisk(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed')
@@ -475,7 +484,23 @@ class EvaluateRepeatedlyTest(test.TestCase):
],
max_number_of_evaluations=1)
- self._verify_summaries(logdir, names_to_values)
+ self._verify_events(logdir, names_to_values)
+
+ def testSummaryAtEndHookWithoutSummaries(self):
+ logdir = os.path.join(self.get_temp_dir(),
+ 'summary_at_end_hook_without_summaires')
+ if gfile.Exists(logdir):
+ gfile.DeleteRecursively(logdir)
+
+ with ops.Graph().as_default():
+ # Purposefully don't add any summaries. The hook will just dump the
+ # GraphDef event.
+ hook = evaluation.SummaryAtEndHook(log_dir=logdir)
+ hook.begin()
+ with self.cached_session() as session:
+ hook.after_create_session(session, None)
+ hook.end(session)
+ self._verify_events(logdir, {})
if __name__ == '__main__':