aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/slim
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-11-08 19:01:15 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:38 -0800
commit17411ee8e7569085c475e8f0bd3f6677a9d44f77 (patch)
tree60b9b6b9dd11fe6875e523134715c977fae4d795 /tensorflow/contrib/slim
parent9d86dc076f74cc9e2683f8ad789930408b0919f7 (diff)
Add hooks keyword argument to slim evaluate_once
to enable TFDBG debugging of slim.evaluation.evaluate_once() Fixes: #13444 PiperOrigin-RevId: 175101022
Diffstat (limited to 'tensorflow/contrib/slim')
-rw-r--r--tensorflow/contrib/slim/BUILD2
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation.py15
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation_test.py46
3 files changed, 53 insertions, 10 deletions
diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD
index 23c23af2f4..c2f106c2b2 100644
--- a/tensorflow/contrib/slim/BUILD
+++ b/tensorflow/contrib/slim/BUILD
@@ -39,6 +39,8 @@ py_test(
"//tensorflow/python:summary",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//tensorflow/python/debug:debug_data",
+ "//tensorflow/python/debug:hooks",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py
index 2d4b08df61..cdb720b36b 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation.py
@@ -153,7 +153,8 @@ def evaluate_once(master,
summary_op=_USE_DEFAULT,
summary_op_feed_dict=None,
variables_to_restore=None,
- session_config=None):
+ session_config=None,
+ hooks=None):
"""Evaluates the model at the given checkpoint path.
Args:
@@ -177,6 +178,8 @@ def evaluate_once(master,
slim.variables.GetVariablesToRestore() is used.
session_config: An instance of `tf.ConfigProto` that will be used to
configure the `Session`. If left as `None`, the default will be used.
+ hooks: A list of additional `SessionRunHook` objects to pass during the
+ evaluation.
Returns:
The value of `final_op` or `None` if `final_op` is `None`.
@@ -184,11 +187,13 @@ def evaluate_once(master,
if summary_op == _USE_DEFAULT:
summary_op = summary.merge_all()
- hooks = [evaluation.StopAfterNEvalsHook(num_evals),]
+ all_hooks = [evaluation.StopAfterNEvalsHook(num_evals),]
if summary_op is not None:
- hooks.append(evaluation.SummaryAtEndHook(
+ all_hooks.append(evaluation.SummaryAtEndHook(
log_dir=logdir, summary_op=summary_op, feed_dict=summary_op_feed_dict))
+ if hooks is not None:
+ all_hooks.extend(hooks)
saver = None
if variables_to_restore is not None:
@@ -203,7 +208,7 @@ def evaluate_once(master,
feed_dict=eval_op_feed_dict,
final_ops=final_op,
final_ops_feed_dict=final_op_feed_dict,
- hooks=hooks,
+ hooks=all_hooks,
config=session_config)
@@ -256,7 +261,7 @@ def evaluation_loop(master,
configure the `Session`. If left as `None`, the default will be used.
timeout: The maximum amount of time to wait between checkpoints. If left as
`None`, then the process will wait indefinitely.
- hooks: A list of additional SessionRunHook objects to pass during
+ hooks: A list of additional `SessionRunHook` objects to pass during
repeated evaluations.
Returns:
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index d9e0f54b72..870f504d10 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import glob
import os
+import shutil
import time
import numpy as np
@@ -29,6 +30,8 @@ from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -230,11 +233,7 @@ class SingleEvaluationTest(test.TestCase):
with self.assertRaises(errors.NotFoundError):
evaluation.evaluate_once('', checkpoint_path, log_dir)
- def testRestoredModelPerformance(self):
- checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
- log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
-
- # First, save out the current model to a checkpoint:
+ def _prepareCheckpoint(self, checkpoint_path):
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
@@ -242,6 +241,13 @@ class SingleEvaluationTest(test.TestCase):
sess.run(init_op)
saver.save(sess, checkpoint_path)
+ def testRestoredModelPerformance(self):
+ checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
+ log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
+
+ # First, save out the current model to a checkpoint:
+ self._prepareCheckpoint(checkpoint_path)
+
# Next, determine the metric to evaluate:
value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
self._labels)
@@ -251,6 +257,36 @@ class SingleEvaluationTest(test.TestCase):
'', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op)
self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
+ def testAdditionalHooks(self):
+ checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
+ log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
+
+ # First, save out the current model to a checkpoint:
+ self._prepareCheckpoint(checkpoint_path)
+
+ # Next, determine the metric to evaluate:
+ value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
+ self._labels)
+
+ dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
+ dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
+ try:
+ # Run the evaluation and verify the results:
+ accuracy_value = evaluation.evaluate_once(
+ '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op,
+ hooks=[dumping_hook])
+ self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
+
+ dump = debug_data.DebugDumpDir(
+ glob.glob(os.path.join(dumping_root, 'run_*'))[0])
+ # Here we simply assert that the dumped data has been loaded and is
+ # non-empty. We do not care about the detailed model-internal tensors or
+ # their values.
+ self.assertTrue(dump.dumped_tensor_data)
+ finally:
+ if os.path.isdir(dumping_root):
+ shutil.rmtree(dumping_root)
+
if __name__ == '__main__':
test.main()