diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-30 14:00:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-30 14:09:08 -0700 |
commit | 48abc110a1a51ef5d6482583c6791c8163873721 (patch) | |
tree | 86b2b5f0ac13e5e288752c13ae6837f59842ccfc /tensorflow/contrib/training | |
parent | feffc075befac53dddc721572493796c8fbffe3c (diff) |
Make SummaryAtEndHook work even if there are no summaries in the graph.
Besides just general resilience to general user code, another motivation is that it
still makes sense to use the hook when there are no summaries in the graph for the side effect of writing out the graph summary.
PiperOrigin-RevId: 210975165
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/evaluation.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/training/python/training/evaluation_test.py | 31 |
2 files changed, 34 insertions, 6 deletions
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index 01bac891da..16a647bf66 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -296,6 +296,7 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook): def begin(self): if self._replace_summary_op: + # This can still remain None if there are no summaries. self._summary_op = summary.merge_all() self._global_step = training_util.get_or_create_global_step() @@ -304,10 +305,12 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook): self._summary_writer = summary.FileWriterCache.get(self._log_dir) def end(self, session): - global_step = training_util.global_step(session, self._global_step) - summary_str = session.run(self._summary_op, self._feed_dict) + if self._summary_op is not None: + global_step = training_util.global_step(session, self._global_step) + summary_str = session.run(self._summary_op, self._feed_dict) + if self._summary_writer: + self._summary_writer.add_summary(summary_str, global_step) if self._summary_writer: - self._summary_writer.add_summary(summary_str, global_step) self._summary_writer.flush() diff --git a/tensorflow/contrib/training/python/training/evaluation_test.py b/tensorflow/contrib/training/python/training/evaluation_test.py index ec47fe5d97..ddd135f047 100644 --- a/tensorflow/contrib/training/python/training/evaluation_test.py +++ b/tensorflow/contrib/training/python/training/evaluation_test.py @@ -427,9 +427,11 @@ class EvaluateRepeatedlyTest(test.TestCase): names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1} return names_to_values, names_to_updates - def _verify_summaries(self, output_dir, names_to_values): + def _verify_events(self, output_dir, names_to_values): """Verifies that the given `names_to_values` are found in the summaries. + Also checks that a GraphDef was written out to the events file. + Args: output_dir: An existing directory where summaries are found. names_to_values: A dictionary of strings to values. @@ -440,7 +442,13 @@ class EvaluateRepeatedlyTest(test.TestCase): self.assertEqual(len(output_filepath), 1) events = summary_iterator.summary_iterator(output_filepath[0]) - summaries = [e.summary for e in events if e.summary.value] + summaries = [] + graph_def = None + for event in events: + if event.summary.value: + summaries.append(event.summary) + elif event.graph_def: + graph_def = event.graph_def values = [] for summary in summaries: for value in summary.value: @@ -448,6 +456,7 @@ class EvaluateRepeatedlyTest(test.TestCase): saved_results = {v.tag: v.simple_value for v in values} for name in names_to_values: self.assertAlmostEqual(names_to_values[name], saved_results[name], 5) + self.assertIsNotNone(graph_def) def testSummariesAreFlushedToDisk(self): checkpoint_dir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed') @@ -475,7 +484,23 @@ class EvaluateRepeatedlyTest(test.TestCase): ], max_number_of_evaluations=1) - self._verify_summaries(logdir, names_to_values) + self._verify_events(logdir, names_to_values) + + def testSummaryAtEndHookWithoutSummaries(self): + logdir = os.path.join(self.get_temp_dir(), + 'summary_at_end_hook_without_summaires') + if gfile.Exists(logdir): + gfile.DeleteRecursively(logdir) + + with ops.Graph().as_default(): + # Purposefully don't add any summaries. The hook will just dump the + # GraphDef event. + hook = evaluation.SummaryAtEndHook(log_dir=logdir) + hook.begin() + with self.cached_session() as session: + hook.after_create_session(session, None) + hook.end(session) + self._verify_events(logdir, {}) if __name__ == '__main__': |