author     Shanqing Cai <cais@google.com>        2017-11-08 19:01:15 -0800
committer  Andrew Selle <aselle@andyselle.com>   2017-11-10 16:14:38 -0800
commit     17411ee8e7569085c475e8f0bd3f6677a9d44f77
tree       60b9b6b9dd11fe6875e523134715c977fae4d795 /tensorflow/contrib/slim
parent     9d86dc076f74cc9e2683f8ad789930408b0919f7
Add hooks keyword argument to slim evaluate_once

This enables TFDBG debugging of slim.evaluation.evaluate_once().
Fixes: #13444
PiperOrigin-RevId: 175101022
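
For readers who want to try the new argument, here is a hypothetical end-to-end sketch in TensorFlow 1.x style. The paths, the toy `weight` variable, and the choice of `tf.metrics.accuracy` are illustrative assumptions, not part of this commit; only the `hooks=[...]` keyword is what the change adds:

```python
# Hypothetical usage sketch (TensorFlow 1.x). The checkpoint, log, and dump
# paths, the toy `weight` variable, and the accuracy metric are placeholders;
# only the `hooks=[...]` argument is what this commit adds.
import os
import tempfile

import tensorflow as tf
from tensorflow.python import debug as tf_debug

slim = tf.contrib.slim

tmp_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(tmp_dir, 'model.ckpt')
log_dir = os.path.join(tmp_dir, 'logs')
dump_root = os.path.join(tmp_dir, 'tfdbg_dumps')

# A trivial "model": one variable, so there is something to checkpoint.
weight = tf.Variable(1.0, name='weight')
labels = tf.constant([1, 0, 1])
predictions = tf.constant([1, 1, 1])
value_op, update_op = tf.metrics.accuracy(labels, predictions)

# Write a checkpoint for evaluate_once to restore.
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  tf.train.Saver().save(sess, checkpoint_path)

# The hook dumps every Session.run() performed during the evaluation to
# dump_root, where it can be inspected offline with TFDBG.
dumping_hook = tf_debug.DumpingDebugHook(dump_root, log_usage=False)

accuracy = slim.evaluation.evaluate_once(
    master='',
    checkpoint_path=checkpoint_path,
    logdir=log_dir,
    eval_op=update_op,
    final_op=value_op,
    hooks=[dumping_hook])  # new keyword argument added by this commit
print('accuracy:', accuracy)
```

With the hook attached, the dump directory fills with one `run_*` subdirectory per `Session.run()` call, which the new test below inspects via `debug_data.DebugDumpDir`.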
Diffstat (limited to 'tensorflow/contrib/slim')
 -rw-r--r--  tensorflow/contrib/slim/BUILD                           |  2
 -rw-r--r--  tensorflow/contrib/slim/python/slim/evaluation.py       | 15
 -rw-r--r--  tensorflow/contrib/slim/python/slim/evaluation_test.py  | 46
3 files changed, 53 insertions, 10 deletions
diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD
index 23c23af2f4..c2f106c2b2 100644
--- a/tensorflow/contrib/slim/BUILD
+++ b/tensorflow/contrib/slim/BUILD
@@ -39,6 +39,8 @@ py_test(
         "//tensorflow/python:summary",
         "//tensorflow/python:training",
         "//tensorflow/python:variables",
+        "//tensorflow/python/debug:debug_data",
+        "//tensorflow/python/debug:hooks",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py
index 2d4b08df61..cdb720b36b 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation.py
@@ -153,7 +153,8 @@ def evaluate_once(master,
                   summary_op=_USE_DEFAULT,
                   summary_op_feed_dict=None,
                   variables_to_restore=None,
-                  session_config=None):
+                  session_config=None,
+                  hooks=None):
   """Evaluates the model at the given checkpoint path.
 
   Args:
@@ -177,6 +178,8 @@ def evaluate_once(master,
       slim.variables.GetVariablesToRestore() is used.
     session_config: An instance of `tf.ConfigProto` that will be used to
       configure the `Session`. If left as `None`, the default will be used.
+    hooks: A list of additional `SessionRunHook` objects to pass during the
+      evaluation.
 
   Returns:
     The value of `final_op` or `None` if `final_op` is `None`.
@@ -184,11 +187,13 @@ def evaluate_once(master,
   if summary_op == _USE_DEFAULT:
     summary_op = summary.merge_all()
 
-  hooks = [evaluation.StopAfterNEvalsHook(num_evals),]
+  all_hooks = [evaluation.StopAfterNEvalsHook(num_evals),]
 
   if summary_op is not None:
-    hooks.append(evaluation.SummaryAtEndHook(
+    all_hooks.append(evaluation.SummaryAtEndHook(
         log_dir=logdir, summary_op=summary_op,
         feed_dict=summary_op_feed_dict))
+  if hooks is not None:
+    all_hooks.extend(hooks)
 
   saver = None
   if variables_to_restore is not None:
@@ -203,7 +208,7 @@ def evaluate_once(master,
       feed_dict=eval_op_feed_dict,
       final_ops=final_op,
       final_ops_feed_dict=final_op_feed_dict,
-      hooks=hooks,
+      hooks=all_hooks,
       config=session_config)
 
 
@@ -256,7 +261,7 @@ def evaluation_loop(master,
       configure the `Session`. If left as `None`, the default will be used.
     timeout: The maximum amount of time to wait between checkpoints. If left
       as `None`, then the process will wait indefinitely.
-    hooks: A list of additional SessionRunHook objects to pass during
+    hooks: A list of additional `SessionRunHook` objects to pass during
       repeated evaluations.
 
   Returns:
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index d9e0f54b72..870f504d10 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import glob
 import os
+import shutil
 import time
 
 import numpy as np
@@ -29,6 +30,8 @@ from tensorflow.contrib.metrics.python.ops import metric_ops
 from tensorflow.contrib.slim.python.slim import evaluation
 from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
 from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.wrappers import hooks
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -230,11 +233,7 @@ class SingleEvaluationTest(test.TestCase):
     with self.assertRaises(errors.NotFoundError):
       evaluation.evaluate_once('', checkpoint_path, log_dir)
 
-  def testRestoredModelPerformance(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
-    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
-
-    # First, save out the current model to a checkpoint:
+  def _prepareCheckpoint(self, checkpoint_path):
     init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                      variables.local_variables_initializer())
     saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
@@ -242,6 +241,13 @@ class SingleEvaluationTest(test.TestCase):
       sess.run(init_op)
       saver.save(sess, checkpoint_path)
 
+  def testRestoredModelPerformance(self):
+    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
+    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
+
+    # First, save out the current model to a checkpoint:
+    self._prepareCheckpoint(checkpoint_path)
+
     # Next, determine the metric to evaluate:
     value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                         self._labels)
@@ -251,6 +257,36 @@ class SingleEvaluationTest(test.TestCase):
         '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op)
     self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
 
+  def testAdditionalHooks(self):
+    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
+    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
+
+    # First, save out the current model to a checkpoint:
+    self._prepareCheckpoint(checkpoint_path)
+
+    # Next, determine the metric to evaluate:
+    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
+                                                        self._labels)
+
+    dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
+    dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
+    try:
+      # Run the evaluation and verify the results:
+      accuracy_value = evaluation.evaluate_once(
+          '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op,
+          hooks=[dumping_hook])
+      self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
+
+      dump = debug_data.DebugDumpDir(
+          glob.glob(os.path.join(dumping_root, 'run_*'))[0])
+      # Here we simply assert that the dumped data has been loaded and is
+      # non-empty. We do not care about the detailed model-internal tensors or
+      # their values.
+      self.assertTrue(dump.dumped_tensor_data)
+    finally:
+      if os.path.isdir(dumping_root):
+        shutil.rmtree(dumping_root)
+
 
 if __name__ == '__main__':
   test.main()
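
The heart of the library change is a rename-and-merge: the internal hook list becomes `all_hooks`, so the caller's `hooks` argument can be appended instead of being shadowed by the local variable of the same name. Below is a minimal, framework-free sketch of that pattern; the helper name `merge_hooks` and the string stand-ins are hypothetical, not TensorFlow API:

```python
from typing import List, Optional, Sequence


def merge_hooks(internal_hooks: Sequence[object],
                user_hooks: Optional[Sequence[object]]) -> List[object]:
  """Mirrors the all_hooks pattern in the patch above (illustrative only).

  Internal hooks are copied into a fresh list, so the caller's sequence is
  never mutated; user-supplied hooks run after the framework's own hooks.
  """
  all_hooks = list(internal_hooks)
  if user_hooks is not None:
    all_hooks.extend(user_hooks)
  return all_hooks


# Usage mirroring evaluate_once: framework hooks first, then user hooks.
internal = ['StopAfterNEvalsHook', 'SummaryAtEndHook']  # stand-ins
user = ['DumpingDebugHook']                             # stand-in
assert merge_hooks(internal, user) == [
    'StopAfterNEvalsHook', 'SummaryAtEndHook', 'DumpingDebugHook']
```

Appending the user hooks last keeps the framework's stopping and summary behavior intact while still letting a debugger hook observe every evaluation run.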